Final Project Information¶

  • You can get 25 to 50 points

  • Goal of the project is to create more thorough analysis of the chosen dataset than in the previous two smaller projects.

  • The dataset selection is up to you however is has come from image, text or time series domain.

  • Every project must include brief description of the dataset

    • Number of instances, number of classes and class balance
    • Examples of the data,
    • ...
  • What metrics scores you have decided to use

    • e.g. Accuracy, Precision, Recall, F1-score etc.
    • Also state which one of the scores is the most important from your point of view given the class balance, task, ...
  • Try at least 3 different models

    • The first two models (one simple and one more complex) should be built from scratch, i.e. create your own architecture and train the models
    • The third model will employ transfer learning techniques
      • Use any set of pre-trained embedding vectors (GloVe, Word2Vec, FastText etc.) or any transformer-based model (this is optional as it is more advanced approach above this course complexity) or pre-trained network in case of image dataset
    • The project will include hyper parameter tuning so try different batch sizes, optimizers etc. and document everything accordingly
  • Mandatory part of every project is a summary at the end in which you summarize the most interesting insight obtained.

  • Result is a Jupyter Notebook with descriptions included or a PDF report + source codes.

  • Deadline is 20. 4. 2022

In [254]:
%%HTML
<script src="require.js"></script>
In [255]:
import plotly.io as pio
pio.renderers.default = "notebook"

Proměnné projektu¶

In [256]:
import os
import pandas as pd
from enum import Enum
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from gensim.parsing.preprocessing import (
    preprocess_string,
    strip_tags as strip_tags_gensim,
    strip_punctuation as strip_punctuation_gensim,
    strip_multiple_whitespaces as strip_multiple_whitespaces_gensim,
    strip_numeric as strip_numeric_gensim,
    remove_stopwords as remove_stopwords_gensim,
    strip_short as strip_short_gensim,
    stem_text as stem_text_gensim,
)
import nltk
from nltk.stem import WordNetLemmatizer
import plotly.express as px
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from wordcloud import WordCloud
import itertools
from sklearn import preprocessing
import pandas as pd
import time
import seaborn as sns
In [1160]:
nltk.download('wordnet')
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/usp/pro0255/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
Out[1160]:
True
In [6]:
EXPERIMENTS_SAVE_DIRECTORY = ['.', 'experiments']
In [1162]:
DIRECTORY = ['..', 'data', 'gutenberg']
In [1163]:
LOAD_5A3S = ["5Authors", "Sentence3" ]
LOAD_15A3S = ["15Authors", "Sentence3"]
In [1164]:
AUTHORS_FILENAME = ['authors.csv']
DATA_FILENAME = ['data.csv']
In [1165]:
TEXT_COLUMN = "text"
LABEL_COLUMN = "label"
BLANK_DESCRIPTION = "Nada"
PROJECT_CSV_DELIMITER = ";"
In [1166]:
TRAIN_SIZE = 0.75
VALIDATION_SIZE = 0.10
TEST_SIZE = 0.15
In [1167]:
RANDOM_STATE = 7
In [1168]:
NORMALIZE_LABEL = 15000
In [12]:
LOG = 'log.csv'
SUMMAR = 'sum.csv'

Pomocné metody¶

Typy¶

In [1170]:
class PreprocessingType(Enum):
    Default = "Default"
    Lowercase = "Lowercase"
    CaseInterpunction = "CaseInterpunction"
    Raw = "Raw"
    Blank = BLANK_DESCRIPTION

Metody¶

In [1171]:
def create_dataset_from_dataframe(dataframe):
    features, target = dataframe[TEXT_COLUMN], dataframe[LABEL_COLUMN]
    return create_dataset_from_Xy(features, target)


def create_dataset_from_Xy(X, y):
    return tf.data.Dataset.from_tensor_slices((X, y))
In [1172]:
def create_encoder_from_path(path):
    authors = pd.read_csv(path, sep=";")
    ids = authors["AuthorId"].values
    encoder = preprocessing.LabelEncoder()
    encoder.fit(ids)
    return encoder
In [1173]:
class CSVLogger(tf.keras.callbacks.Callback):
    def __init__(self, path):
        self.path = os.path.sep.join([path, LOG])
        self.timetaken = time.time()
        self.state = {}

    def on_epoch_end(self, epoch, logs={}):
        logs["time"] = time.time() - self.timetaken
        self.state[epoch] = logs

    def on_train_end(self, logs={}):
        headers = []
        for k, v in self.state.items():
            headers = self.state[k].keys()
            break

        data = {k: self.state[k].values() for k, v in self.state.items()}
        df = pd.DataFrame.from_dict(data, orient="index")
        df.columns = headers
        df.to_csv(self.path, sep=';')
In [1174]:
class Visualizer:    
    def __init__(self):
        pass
        
    def create_max_min_mean_len(self, tuples):
        #d = vis.create_max_min_mean_len([("test", test), ("feste", test)]) test
        res = pd.DataFrame()
        
        for name, data in tuples:
            x = data.copy()
            x['len'] = x[TEXT_COLUMN].apply(len)
            x = x.groupby(LABEL_COLUMN).len.agg(['mean', 'min', 'max'])
            
            together = pd.DataFrame.from_dict(
                {"together": 
                    {
                    'mean': np.mean(x['mean']), 
                    'min': np.min(x['min']), 
                    'max': np.max(x['max'])
                     }
                }, 
                orient="index"
            )
            
            x = pd.concat([together, x])
            x = x.reset_index()
            x = pd.melt(x, id_vars=['index'], var_name='value_type', value_name='value')
            x['df_type'] = name
            
            res = pd.concat([res, x])
            
        return res
    
    def show_mean(self, dataframe):
        return self.show_type(dataframe, 'mean') 
    
    def show_max(self, dataframe):
        return self.show_type(dataframe, 'max') 
    
    def show_min(self, dataframe):
        return self.show_type(dataframe, 'min') 
    
    def show_type(self, dataframe, spe_type):
        return px.bar(dataframe[dataframe.value_type == spe_type], x="df_type", y="value", color='index', barmode='group')
    
    def seq_dist(self, dataframe):
        x = dataframe.copy()
        x['len'] = x[TEXT_COLUMN].apply(len)
        return px.histogram(x, x='len', color='label', title='')
    
    
    def create_all_words(self, dataframe):
        x = dataframe.copy()
        all_words = list(itertools.chain.from_iterable([sentence.split(' ') for sentence in x[TEXT_COLUMN]]))
        dist = nltk.FreqDist(all_words)
        return dist
    
    def generate_top_words(self, dataframe):
        x = dataframe.copy()
        res = {}
        
        for current_label in np.unique(test.label.values):
            subframe = x[x.label == current_label]
            res[current_label] = self.show_top_words(subframe)
            
        res['all'] = self.show_top_words(x)
        
        return res
    
    def show_top_words(self, dataframe, n=10):
        dist = self.create_all_words(dataframe)
        df = pd.DataFrame.from_dict(dict(dist), orient="index").reset_index()
        df.columns = ['word', 'freq']
        df = df.sort_values(by='freq', ascending=False)
        first_n = df.iloc[0:n, :]
        return px.bar(first_n, x="word", y="freq", color='word', title=f"{n} most freq words")
        
    def show_wordcloud(self, wordcloud):
        plt.figure(figsize=[10, 10])        
        plt.axis("off")
        x = plt.imshow(wordcloud, interpolation="bilinear")
        return x
    
    def generate_wordclouds(self, dataframe):
        #result = vis.generate_wordclouds(test) test
        x = dataframe.copy()
        res = {}
        
        for current_label in np.unique(test.label.values):
            subframe = x[x.label == current_label]
            res[current_label] = self.wordcloud(subframe)
            
        res['all'] = self.wordcloud(x)
        
        return res
    
    def wordcloud(self, dataframe, max_words=100):
        x = dataframe.copy()
        current_text = " ".join(x.text.values)
        wordcloud = WordCloud(max_font_size=50, max_words=100, background_color="white").generate(current_text)
        return wordcloud
In [1175]:
def prediction_to_labels(y_pred):
    y_pred = np.argmax(y_pred, axis=1)
    return y_pred
In [1176]:
BLACKLIST = [
    'CHAPTER'
]

class TextPreprocessor:
    def __init__(self) -> None:
        self.strip_short_default = self.create_strip_short_method(3)
        self.lemma_text = self.create_lemma_text()

    def strip_tags(self, text):
        return strip_tags_gensim(text)
    
    def strip_upper_words(self, text):
        return [word for word in text.split(' ') if word.upper() != word]
    
    
    def remove_when_blacklisted(self, text):
        current_text = set(text.split(' '))
        blacklist = set(BLACKLIST)

        l = len(current_text.intersection(blacklist))
    
        if l > 0:
            return ''

        return text
    
    def strip_punctuation(self, text):
        return strip_punctuation_gensim(text)

    def strip_multiple_whitespaces(self, text):
        return strip_multiple_whitespaces_gensim(text)

    def strip_numeric(self, text):
        return strip_numeric_gensim(text)

    def strip_stopwords(self, text):
        return remove_stopwords_gensim(text)

    def strip_short(self, text, minsize=3):
        return strip_short_gensim(text, minsize)

    def strip_short(self, text):
        return strip_short_gensim(text)

    def create_strip_short_method(self, minsize=3):
        # TODO: fix
        print(f"Creating shorting method with min = {minsize}")

        def strip_short(text, minsize=minsize):
            return strip_short_gensim(text, minsize)

        return strip_short

    def stem_text(self, text):
        return stem_text_gensim(text)

    def to_lowercase(self, text):
        return text.lower()

    def create_lemma_text(self):
        instance = WordNetLemmatizer()
        print(f"Creating lemma method with instance {instance}")

        def lemma_text(text):
            word_list = nltk.word_tokenize(text)
            return " ".join([instance.lemmatize(word) for word in word_list])

        return lemma_text

    def create_preprocess_string_func(self, filters, tokenized=False):
        def preprocess_func(text):
            result = preprocess_string(text, filters)
            return result if tokenized else " ".join(result)

        return preprocess_func

    def default_preprocessing(self):
        return self.create_preprocess_string_func(
            [
                self.remove_when_blacklisted,
                self.to_lowercase,
                self.strip_punctuation,
                self.strip_tags,
                self.strip_multiple_whitespaces,
                self.strip_numeric,
                self.strip_stopwords,
                self.strip_short,
                self.lemma_text,
            ]
        )
    
    def default_lowerinterpunction(self):
        return self.create_preprocess_string_func(
            [
                self.remove_when_blacklisted,
                self.to_lowercase,
                self.strip_punctuation,
                self.strip_multiple_whitespaces,
                self.strip_numeric
            ]
        )
In [1177]:
class PreprocessingFactory:
    def __init__(self) -> None:
        self.preprocessor = TextPreprocessor()
        self.build_dic()

    def build_dic(self):
        self.dic = {
            PreprocessingType.Default: self.preprocessor.default_preprocessing(),
            PreprocessingType.Lowercase: self.preprocessor.create_preprocess_string_func(
                [self.preprocessor.to_lowercase]
            ),
            PreprocessingType.Raw: lambda x: x,
            PreprocessingType.CaseInterpunction: self.preprocessor.default_lowerinterpunction(),
            PreprocessingType.Blank: None,
        }

    def create(self, preprocessing_type):
        return self.dic[preprocessing_type]
In [1178]:
def get_load_path(directory, combination, filename):
    return os.path.sep.join(directory + combination + filename)

def get_load_path_53(filename=DATA_FILENAME):
    return get_load_path(DIRECTORY, LOAD_5A3S, filename)

def get_load_path_153(filename=DATA_FILENAME):
    return get_load_path(DIRECTORY, LOAD_15A3S, filename)

def load_dataset_from_path_with_normalization(path, normalize=None, preprocessing_type=None):
    factory = PreprocessingFactory()
    
    normalize_final = None
    
    if normalize is not None:
        print('Specified normalize method')
        normalize_final = normalize
    else:
        print(f'Specified type {preprocessing_type.value}')
        if preprocessing_type is None:
            normalize_final = factory.create(PreprocessingType.Default)
        else:
            normalize_final = factory.create(preprocessing_type)

    dataset = load_dataset_from_path(path)
    dataset[TEXT_COLUMN] = dataset[TEXT_COLUMN].apply(normalize_final)
    return dataset


def load_dataset_from_path(path):
    dataset = pd.read_csv(path, sep=PROJECT_CSV_DELIMITER, header=None)
    dataset.columns = [TEXT_COLUMN, LABEL_COLUMN]
    return dataset


def normalize_dataframe_to_size(dataframe, size):
    all_labels = dataframe[LABEL_COLUMN].unique()
    new_dataframe = pd.DataFrame()

    for label in all_labels:
        selected_dataframe = dataframe[dataframe.label == label].sample(size, random_state=RANDOM_STATE)
        new_dataframe = pd.concat([new_dataframe, selected_dataframe])

    return shuffle(new_dataframe, random_state=RANDOM_STATE)


def split_dataframe_to_train_test_valid(dataframe, test_size=TEST_SIZE, valid_size=VALIDATION_SIZE):
    features, target = dataframe[TEXT_COLUMN], dataframe[LABEL_COLUMN]
    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=RANDOM_STATE)
    X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=valid_size, random_state=RANDOM_STATE)
    
    print(f"Train {X_train.shape}")
    print(f"Valid {X_valid.shape}")
    print(f"Test {X_test.shape}")
    
    return X_train, X_valid, X_test, y_train, y_valid, y_test
In [1179]:
test = load_dataset_from_path_with_normalization(get_load_path_53(), None, PreprocessingType.Raw)
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Raw
In [1180]:
test = normalize_dataframe_to_size(test, NORMALIZE_LABEL)
In [1181]:
X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(test)
Train (57375,)
Valid (6375,)
Test (11250,)

Popis datové sady¶

Datová sada je vyrobna vlastní sílou. Jde o textová data, která se parují s autorem těchto dat. Z hlediska textové stylu tento text můžeme zařadit do uměleckého stylu.

Krátký význam datové sady¶

Motivací proč datová sada byla vytvořena je pokusit se identifikovat autora textu. Tento problém je určení autorství textu. Pakliže se na to budeme dívat skrz oči datového analytika, pak je to převlečený klasifikační problém, kde se snažíme předikonvat n tříd. Zde jako třídu myslíme autora textu.

Co je to Projekt Gutenberg?¶

Projekt Gutenberg je dobrovolnická snaha digitalizovat, archivovat a distribuovat kulturní díla. Byl založen v roce 1971 a je nejstarší digitální knihovnou. Většina děl jsou plné texty knih se statusem volného díla.

Tvorba¶

Jak již bylo zmíněno, datová sada byla vytvořena částečně vlastní sílou a částečně s výpomocí některých veřejně dostupných knihoven nalezených na internetu.

V mém případě jsem využil knihovnu do R (https://cran.r-project.org/web/packages/gutenbergr/vignettes/intro.html), která poskytuje jednoduché API pro stáhnutí všech děl z Projektu Gutenberg.

Před samotným stažením byla byla všechna dostupná díla vyfiltrovaná dle anglického jazyka, a tak aby měla autora. Všechna tato díla byla stažena ve formátu json, tak aby z nich pak šlo jednoduše vytvořit potřebný dataset. Příklad takové json souboru lze najít v adresáři example.

Poté co se mi podařilo úspěšně stáhnut všechna díla byly vytvořeny datové sady. Vždy byl specifikován počet autorů a velikost textové části. Pro představu pokud dané dílo bylo od autora námi požadovaného, pak celé dílo bylo rozsekáno na textové části o n větách. Kde věta byla ukončena dle typických znaků jako jsou {., !, ?}. Takto byly vytvořeny záznamy, které reprezentují strukturu datové sady.

Záznamy datové sady¶

Struktura datové sady:

  • text - textová data o n větách
  • label - id autora

V tomto projektu do předmětu byly zpracovány datové sady o 5 a 10 autorech s velikosti textové sady 3.

Poznatky¶

  • Díla autorů jsou různě dlouhá. Po rozsekání na úseky mohl být určitý autor ve velké výhodě.
  • Pokud by bylo pracováno z celou vytvořenou datovou sadou, pak běh modelů se předpokládá jako časově vysoce náročný. Z tohoto důvodu byla data zmenšena na menší počet záznamů.
  • Zmenšením datové sady jsme se zbavili nevyváženosti a vytvořili si tak vyváženou datovou sadou, s kterou se pracuje jednodušeji.
  • Stažená umělecká díla zabíraly kolem 12GB.
  • Velikost vytvořené datové sady se lišila. Bývala závislá na počtu autorů, velikosti textové části.

Představení datové sady¶

Zde bude představena datová sada blíže.

  • Vizualizace v dataframu.
  • Grafy.
  • Potvrzení dat výše zmiňovaných.
  • ...

Ukázka dat¶

In [1182]:
data_test = load_dataset_from_path_with_normalization(get_load_path_53(), lambda x: x) 
authors_test = pd.read_csv(get_load_path_53(AUTHORS_FILENAME), sep=';')
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified normalize method
In [1183]:
data_test.head()
Out[1183]:
text label
0 THE TRAGEDY OF PUDD'NHEAD WILSON by Mark Twai... 53
1 These chapters are right, now, in every detail... 53
2 Given under my hand this second day of January... 53
3 In 1830 it was a snug collection of modest one... 53
4 Then that house was complete, and its contentm... 53

Každá datová sada byla vytvořena s doprovodným csv souborem obsajujícím další data k autorovi. Jak zde můžeme vidět, víme jak se autor jmenoval.

In [1184]:
authors_test.head()
Out[1184]:
AuthorId Author
0 761 Lytton, Edward Bulwer Lytton, Baron
1 1800 Ebers, Georg
2 53 Twain, Mark
3 8659 Kingston, William Henry Giles
4 1285 Parker, Gilbert
In [1185]:
label_counts = data_test.groupby(by=["label"]).size().reset_index(name="counts")
label_counts.head()
Out[1185]:
label counts
0 53 74224
1 761 144504
2 1285 133159
3 1800 107653
4 8659 140495

Lze pozorovat nevyváženost datové sady, a proto bude vybrán z každé třídy pouze určitý počet záznamů.

  • Zmenšení datové sady za účelem urychlení.
  • Zajištění vyváženosti.
In [1186]:
px.bar(data_frame=label_counts, y="counts", barmode="group")
In [1187]:
norm_data_test = normalize_dataframe_to_size(data_test, NORMALIZE_LABEL)
label_counts = norm_data_test.groupby(by=["label"]).size().reset_index(name="counts")
px.bar(data_frame=label_counts, y="counts")

Předzpracování textových dat a jejich vizualizace¶

Je dobrý přístup textová data před samotným využitím předzpracovat. Toto předzpracování pro příklad zajištuje:

  • Transformaci do malých písmen.
  • Transformaci do základního tvaru.
  • Odstranění šumu, příklad tweety většinou mívají tendenci v sobě obsahovat reference na jiné uživatele v síti.
  • ...

Díky toho získáme některé žadoucí výsledky:

  • Zmenšení velikosti dat, tudíž rychlejší následné zpracování.
  • Extrakce relevantních slov, feature.
  • ...
In [1188]:
vis = Visualizer()
In [1189]:
raw_data = normalize_dataframe_to_size(
    load_dataset_from_path_with_normalization(
        get_load_path_53(), 
        None, 
        PreprocessingType.Raw), 
    NORMALIZE_LABEL
)
lowered_data = normalize_dataframe_to_size(
    load_dataset_from_path_with_normalization(
        get_load_path_53(), 
        None, 
        PreprocessingType.Lowercase
    ), 
    NORMALIZE_LABEL
)
default_data = normalize_dataframe_to_size(
    load_dataset_from_path_with_normalization(
        get_load_path_53(), 
        None, 
        PreprocessingType.Default), 
    NORMALIZE_LABEL
)
lowerinterpunction_data = normalize_dataframe_to_size(
    load_dataset_from_path_with_normalization(
        get_load_path_53(), 
        None, 
        PreprocessingType.CaseInterpunction), 
    NORMALIZE_LABEL
)
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Raw
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Lowercase
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Default
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type CaseInterpunction
In [1190]:
print(len(raw_data))
raw_data.head()
75000
Out[1190]:
text label
182816 Thrice a deep breath the knight relieved did d... 761
318651 But they were not cool deeps by any means, for... 53
285365 retorted Klea. "One of her escort threw them t... 1800
322281 The prince is an educated gentleman. His cultu... 53
163774 I thought: Now, this is the man whom I saw twe... 53
In [1191]:
print(len(lowered_data))
lowered_data.head()
75000
Out[1191]:
text label
182816 thrice a deep breath the knight relieved did d... 761
318651 but they were not cool deeps by any means, for... 53
285365 retorted klea. "one of her escort threw them t... 1800
322281 the prince is an educated gentleman. his cultu... 53
163774 i thought: now, this is the man whom i saw twe... 53
In [1192]:
print(len(lowerinterpunction_data))
lowerinterpunction_data.head()
75000
Out[1192]:
text label
182816 thrice a deep breath the knight relieved did d... 761
318651 but they were not cool deeps by any means for ... 53
285365 retorted klea one of her escort threw them to ... 1800
322281 the prince is an educated gentleman his cultur... 53
163774 i thought now this is the man whom i saw twent... 53
In [1193]:
print(len(default_data))
default_data.head()
75000
Out[1193]:
text label
182816 thrice deep breath knight relieved draw fair v... 761
318651 cool deep mean sun ray weltering hot little br... 53
285365 retorted klea escort threw drop subject pray w... 1800
322281 prince educated gentleman culture european eur... 53
163774 thought man saw year ago occasion went free ha... 53

Porovnání vybraného textu¶

In [1194]:
selected_index = 76
In [1195]:
raw_data.text.values[selected_index]
Out[1195]:
'It was the morning of the twentieth day. At noon we would reach Carson City, the capital of Nevada Territory. We were not glad, but sorry.'
In [1196]:
lowered_data.text.values[selected_index]
Out[1196]:
'it was the morning of the twentieth day. at noon we would reach carson city, the capital of nevada territory. we were not glad, but sorry.'
In [1197]:
default_data.text.values[selected_index]
Out[1197]:
'morning twentieth day noon reach carson city capital nevada territory glad sorry'
In [1198]:
lowerinterpunction_data.text.values[selected_index]
Out[1198]:
'it was the morning of the twentieth day at noon we would reach carson city the capital of nevada territory we were not glad but sorry'

Vizualizace max, min, mean délky textu¶

In [1199]:
max_min_mean = vis.create_max_min_mean_len([
    ("RAW", raw_data), 
    ("DEFAULT", default_data), 
    ("LOWER", lowered_data),
    ("LOWERINTEPUNCTION", lowerinterpunction_data)
])
In [1200]:
vis.show_max(max_min_mean)

Zde můžeme pozorovat, že default předzpracování zmenší délku věty na polovinu. Modely nám z tohoto důvodu poběží rychleji a přepokládáme, že tímto předzpracování jsme zároveň nechali nejpodstatnější informace v textových datech.

Toto ale nemusí být pravdou a ověření najdeme ve výsledcích.

In [1201]:
vis.show_mean(max_min_mean)
In [1202]:
vis.show_min(max_min_mean)

Vizualizace četnosti¶

Raw¶

In [1203]:
vis.seq_dist(raw_data)

Default¶

In [1204]:
vis.seq_dist(default_data)

Lower¶

In [1205]:
vis.seq_dist(lowered_data)

LowerInterpunction¶

In [1206]:
vis.seq_dist(lowerinterpunction_data)

Vizualizace obláčků¶

Raw¶

In [1207]:
raw_clouds = vis.generate_wordclouds(raw_data)
In [1208]:
vis.show_wordcloud(raw_clouds['all'])
Out[1208]:
<matplotlib.image.AxesImage at 0x7fe86c20e1f0>

Default¶

In [1209]:
default_clouds = vis.generate_wordclouds(default_data)
In [1210]:
vis.show_wordcloud(default_clouds['all'])
Out[1210]:
<matplotlib.image.AxesImage at 0x7fe7c009d3a0>

Lower¶

In [1211]:
lower_clouds = vis.generate_wordclouds(lowered_data)
In [1212]:
vis.show_wordcloud(lower_clouds['all'])
Out[1212]:
<matplotlib.image.AxesImage at 0x7fe8b4e32730>

LowerInterpunction¶

In [1213]:
lowerinterpunction_clouds = vis.generate_wordclouds(lowerinterpunction_data)
In [1214]:
vis.show_wordcloud(lowerinterpunction_clouds['all'])
Out[1214]:
<matplotlib.image.AxesImage at 0x7fe7645de3d0>
In [1215]:
vis.show_wordcloud(lowerinterpunction_clouds[53])
Out[1215]:
<matplotlib.image.AxesImage at 0x7fe86c229b50>
In [1216]:
vis.show_wordcloud(lowerinterpunction_clouds[8659])
Out[1216]:
<matplotlib.image.AxesImage at 0x7fe7c0758a90>

Vizualizace nejvíce četných slov¶

Raw¶

In [1217]:
raw_top_words = vis.generate_top_words(raw_data)
In [1218]:
raw_top_words.keys()
Out[1218]:
dict_keys([53, 761, 1285, 1800, 8659, 'all'])
In [1219]:
raw_top_words[53]

Default¶

Zde můžeme pozorovat nejzajímavější rozdíl mezi četností jednotlivých slov. Každý z autorů většinou mívá jiná nejvíce používaná slova. Toto může být klíčové pro správnou predikci daného autora.

  • "Jeho nejoblíbenější slovíčko?"
  • "Nejoblíbenější předložka."
  • ...
In [1220]:
default_top_words = vis.generate_top_words(default_data)
In [1221]:
default_top_words[53]
In [1222]:
default_top_words[8659]
In [1223]:
default_top_words[1800]

Lower¶

In [1224]:
lower_top_words = vis.generate_top_words(lowered_data)
In [1225]:
lower_top_words[53]

LowerInterpunction¶

In [1226]:
lowerinterpunction_top_words = vis.generate_top_words(lowerinterpunction_data)
In [1227]:
lowerinterpunction_top_words[53]

Výběr metriky¶

Jak již jsme zmínili výše, datová sada byla upravena tak, ať pracujeme s vyváženou. Tudíž každý autor má obsaženo stejný počet záznamů. Z toho důvodu bude vybraná metrika přesnost (accuracy).

Tato metrika nám poskytne procentuální hodnotu, která říká jak přesně jsme schopni predikovat, že tento text napsal tento autor.

Modely¶

Připravení dat¶

  • RAW
  • LOWER
  • DEFAULT
  • LOWER_I
In [1228]:
ALL_KEYS = ['RAW', "LOWER", "DEFAULT", "LOWER_I"]
In [1249]:
def load_5():
    return {
         "RAW": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_53(), 
            None, 
            PreprocessingType.Raw), 
        NORMALIZE_LABEL
        ),
        "LOWER": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_53(), 
            None, 
            PreprocessingType.Lowercase
        ), 
        NORMALIZE_LABEL
        ),
        "DEFAULT": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_53(), 
            None, 
            PreprocessingType.Default), 
        NORMALIZE_LABEL
        ),
        "LOWER_I": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_53(), 
            None, 
            PreprocessingType.CaseInterpunction), 
        NORMALIZE_LABEL
        )
    }
In [1257]:
def load_15():
    return {
        "RAW":  normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_153(), 
            None, 
            PreprocessingType.Raw), 
            NORMALIZE_LABEL
        ),
        "LOWER": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_153(), 
            None, 
            PreprocessingType.Lowercase
        ), 
        NORMALIZE_LABEL
        ),
        "DEFAULT": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_153(), 
            None, 
            PreprocessingType.Default), 
        NORMALIZE_LABEL
        ),
        "LOWER_I": normalize_dataframe_to_size(
            load_dataset_from_path_with_normalization(
            get_load_path_153(), 
            None, 
            PreprocessingType.CaseInterpunction), 
        NORMALIZE_LABEL
        )
    }   
In [1258]:
data_5 = load_5()
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Raw
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Lowercase
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Default
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type CaseInterpunction
In [1259]:
data_15 = load_15()
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Raw
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Lowercase
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type Default
Creating shorting method with min = 3
Creating lemma method with instance <WordNetLemmatizer>
Specified type CaseInterpunction
In [1260]:
data = {
    "15": data_15,
    "5": data_5
}
In [1254]:
data[]
Out[1254]:
{'5': {'RAW':                                                      text  label
  182816  Thrice a deep breath the knight relieved did d...    761
  318651  But they were not cool deeps by any means, for...     53
  285365  retorted Klea. "One of her escort threw them t...   1800
  322281  The prince is an educated gentleman. His cultu...     53
  163774  I thought: Now, this is the man whom I saw twe...     53
  ...                                                   ...    ...
  52919   They felt that had they done so, they would na...   8659
  168164  *****  To W. D. Howells, in America:          ...     53
  51857   I daresay I shall not succeed at first, but th...   8659
  87638   cried Bill. "We have gained an inch, and in an...   8659
  259728  But that good man forgot not, even over the wi...   1800
  
  [75000 rows x 2 columns],
  'LOWER':                                                      text  label
  182816  thrice a deep breath the knight relieved did d...    761
  318651  but they were not cool deeps by any means, for...     53
  285365  retorted klea. "one of her escort threw them t...   1800
  322281  the prince is an educated gentleman. his cultu...     53
  163774  i thought: now, this is the man whom i saw twe...     53
  ...                                                   ...    ...
  52919   they felt that had they done so, they would na...   8659
  168164  ***** to w. d. howells, in america: kaltenleut...     53
  51857   i daresay i shall not succeed at first, but th...   8659
  87638   cried bill. "we have gained an inch, and in an...   8659
  259728  but that good man forgot not, even over the wi...   1800
  
  [75000 rows x 2 columns],
  'DEFAULT':                                                      text  label
  182816  thrice deep breath knight relieved draw fair v...    761
  318651  cool deep mean sun ray weltering hot little br...     53
  285365  retorted klea escort threw drop subject pray w...   1800
  322281  prince educated gentleman culture european eur...     53
  163774  thought man saw year ago occasion went free ha...     53
  ...                                                   ...    ...
  52919   felt naturally accused influenced vindictive f...   8659
  168164  howells america kaltenleutgeben bei wien aug d...     53
  51857   daresay shall succeed like trying piece open g...   8659
  87638   cried gained inch minute shall gained inch hurrah   8659
  259728  good man forgot wine jar pleasure folk albeit ...   1800
  
  [75000 rows x 2 columns],
  'LOWER_I':                                                      text  label
  182816  thrice a deep breath the knight relieved did d...    761
  318651  but they were not cool deeps by any means for ...     53
  285365  retorted klea one of her escort threw them to ...   1800
  322281  the prince is an educated gentleman his cultur...     53
  163774  i thought now this is the man whom i saw twent...     53
  ...                                                   ...    ...
  52919   they felt that had they done so they would nat...   8659
  168164  to w d howells in america kaltenleutgeben bei ...     53
  51857   i daresay i shall not succeed at first but the...   8659
  87638   cried bill we have gained an inch and in anoth...   8659
  259728  but that good man forgot not even over the win...   1800
  
  [75000 rows x 2 columns]}}
In [1262]:
data['5']
Out[1262]:
{'RAW':                                                      text  label
 182816  Thrice a deep breath the knight relieved did d...    761
 318651  But they were not cool deeps by any means, for...     53
 285365  retorted Klea. "One of her escort threw them t...   1800
 322281  The prince is an educated gentleman. His cultu...     53
 163774  I thought: Now, this is the man whom I saw twe...     53
 ...                                                   ...    ...
 52919   They felt that had they done so, they would na...   8659
 168164  *****  To W. D. Howells, in America:          ...     53
 51857   I daresay I shall not succeed at first, but th...   8659
 87638   cried Bill. "We have gained an inch, and in an...   8659
 259728  But that good man forgot not, even over the wi...   1800
 
 [75000 rows x 2 columns],
 'LOWER':                                                      text  label
 182816  thrice a deep breath the knight relieved did d...    761
 318651  but they were not cool deeps by any means, for...     53
 285365  retorted klea. "one of her escort threw them t...   1800
 322281  the prince is an educated gentleman. his cultu...     53
 163774  i thought: now, this is the man whom i saw twe...     53
 ...                                                   ...    ...
 52919   they felt that had they done so, they would na...   8659
 168164  ***** to w. d. howells, in america: kaltenleut...     53
 51857   i daresay i shall not succeed at first, but th...   8659
 87638   cried bill. "we have gained an inch, and in an...   8659
 259728  but that good man forgot not, even over the wi...   1800
 
 [75000 rows x 2 columns],
 'DEFAULT':                                                      text  label
 182816  thrice deep breath knight relieved draw fair v...    761
 318651  cool deep mean sun ray weltering hot little br...     53
 285365  retorted klea escort threw drop subject pray w...   1800
 322281  prince educated gentleman culture european eur...     53
 163774  thought man saw year ago occasion went free ha...     53
 ...                                                   ...    ...
 52919   felt naturally accused influenced vindictive f...   8659
 168164  howells america kaltenleutgeben bei wien aug d...     53
 51857   daresay shall succeed like trying piece open g...   8659
 87638   cried gained inch minute shall gained inch hurrah   8659
 259728  good man forgot wine jar pleasure folk albeit ...   1800
 
 [75000 rows x 2 columns],
 'LOWER_I':                                                      text  label
 182816  thrice a deep breath the knight relieved did d...    761
 318651  but they were not cool deeps by any means for ...     53
 285365  retorted klea one of her escort threw them to ...   1800
 322281  the prince is an educated gentleman his cultur...     53
 163774  i thought now this is the man whom i saw twent...     53
 ...                                                   ...    ...
 52919   they felt that had they done so they would nat...   8659
 168164  to w d howells in america kaltenleutgeben bei ...     53
 51857   i daresay i shall not succeed at first but the...   8659
 87638   cried bill we have gained an inch and in anoth...   8659
 259728  but that good man forgot not even over the win...   1800
 
 [75000 rows x 2 columns]}

Výsledky, které chceme uložit¶

In [1263]:
BLANK = '-'
In [1264]:
class Fields(Enum):
    ModelName = 'ModelName'
    BatchSize = 'BatchSize'
    Optimizer = 'Optimizer'
    LR = 'LR'
    Epochs = 'Epochs' 
    EmbeddingSize = 'EmbeddingSize'
    Time = 'Time'
    Accuracy = 'Accuracy'
    Hits = 'Hits'
    Miss = 'Miss'
    Key = 'Key'
    SeqLen = 'SeqLen'
    VocabSize = 'VocabSize'
    TrainableEmbedding = 'TrainableEmbedding'
    ConfMatrix = "ConfMatrix"
    ModelType = "Type"
    TransformerName = "TransformerName"
    NumberOfAuthors= "NumberOfAuthors"
In [1265]:
def create_value(
    ModelName=BLANK,
    BatchSize=BLANK,
    Optimizer=BLANK,
    Epochs=BLANK,
    EmbeddingSize=BLANK,
    Time=BLANK,
    Accuracy=BLANK,
    LR=BLANK,
    Hits=BLANK,
    Miss=BLANK,
    Key=BLANK,
    SeqLen=BLANK,
    VocabSize=BLANK,
    TrainableEmbedding=BLANK,
    ConfMatrix=BLANK,
    ModelType=BLANK,
    TransformerName=BLANK,
    NumberOfAuthors=BLANK
):
    return {
        Fields.ModelName.value: ModelName,
        Fields.BatchSize.value: BatchSize,
        Fields.Optimizer.value: Optimizer,
        Fields.LR.value: LR,
        Fields.Epochs.value: Epochs,
        Fields.EmbeddingSize.value: EmbeddingSize,
        Fields.Time.value: Time,
        Fields.Accuracy.value: Accuracy,
        Fields.Hits.value: Hits,
        Fields.Miss.value: Miss,
        Fields.Key.value: Key,
        Fields.SeqLen.value: SeqLen,
        Fields.VocabSize.value: VocabSize,
        Fields.TrainableEmbedding.value: TrainableEmbedding,
        Fields.ConfMatrix.value: ConfMatrix,
        Fields.ModelType.value: ModelType,
        Fields.TransformerName.value: TransformerName,
        Fields.NumberOfAuthors.value: NumberOfAuthors
    }

Definice parametrů¶

In [1266]:
BATCH_SIZE = 32

BATCH_SIZES = [
  #32,
  64,
  #128,
  #256
]

LR = 0.001

TRANSFORMER_LR = [
    0.001,
    5e-5, 
#     4e-5, 
#     3e-5, 
#     2e-5    
]

ADAM = tf.keras.optimizers.Adam
RMS = tf.keras.optimizers.RMSprop

OPTIMIZERS = [
  ADAM,
  RMS  
]

EMB_SIZES = [
  50,
  #100,
  #150,
  #200,
  #250,
  300
]

EPOCHS = 10

LOSS = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

METRICS = [tf.keras.metrics.SparseCategoricalAccuracy("accuracy")]
In [1267]:
PATIENCE = 3
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True, mode="auto")
In [1268]:
def setup_directory():    
    path = os.path.sep.join(EXPERIMENTS_SAVE_DIRECTORY)
    index = len(os.listdir(os.path.sep.join(EXPERIMENTS_SAVE_DIRECTORY))) 
    
    path = os.path.sep.join([path, str(index)])
    
    
    if not os.path.exists(path):
        os.makedirs(path)
        
        
    return path
In [1269]:
def save_experiment_info(
    path,
    ModelName=BLANK,
    BatchSize=BLANK,
    Optimizer=BLANK,
    Epochs=BLANK,
    EmbeddingSize=BLANK,
    Time=BLANK,
    Accuracy=BLANK,
    LR=BLANK,
    Hits=BLANK,
    Miss=BLANK,
    Key=BLANK,
    SeqLen=BLANK,
    VocabSize=BLANK,
    TrainableEmbedding=BLANK,
    ConfMatrix=BLANK,
    ModelType=BLANK,
    TransformerName=BLANK,
    NumberOfAuthors=BLANK
):
    val = create_value(
        ModelName=ModelName,
        BatchSize=BatchSize,
        Optimizer=Optimizer,
        Epochs=BLANK,
        EmbeddingSize=EmbeddingSize,
        Time=Time,
        Accuracy=Accuracy,
        LR=LR,
        Hits=Hits,
        Miss=Miss,
        Key=Key,
        SeqLen=SeqLen,
        VocabSize=VocabSize,
        TrainableEmbedding=TrainableEmbedding,
        ConfMatrix=ConfMatrix,
        ModelType=ModelType,
        TransformerName=TransformerName,
        NumberOfAuthors=NumberOfAuthors
    )
    df = pd.DataFrame.from_dict(val, orient="index")
    path = os.path.sep.join([path, SUMMAR])
    print(f"Saving to {path}")
    df.to_csv(path, sep=';')
    return df
In [1270]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

Jednoduchý vlastnoruční model¶

In [1279]:
def run_dense_model(
    max_tokens,
    output_sequence_length,
    number_of_authors,
    emb_size,
    key,
    loss,
    optimizer,
    metrics,
    batch_size,
    epochs,
    lr
):
    MODEL_NAME = "DENSE"
    current_path = setup_directory()
    
    current_data = data[str(number_of_authors)][key]
    loader = get_load_path_53 if number_of_authors == 5 else get_load_path_153
    encoder = create_encoder_from_path(loader(AUTHORS_FILENAME))
    
    X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(current_data)
    
    y_test = encoder.transform(y_test)
    y_train = encoder.transform(y_train)
    y_valid = encoder.transform(y_valid)
    
    train_ds = create_dataset_from_Xy(X_train, y_train)
    test_ds = create_dataset_from_Xy(X_test, y_test)
    valid_ds = create_dataset_from_Xy(X_valid, y_valid)
    
    
    vector_layer = tf.keras.layers.TextVectorization(
        max_tokens=max_tokens,
        output_mode='int',
        standardize=None,
        output_sequence_length=output_sequence_length,
        split='whitespace'
    )

    vector_layer.adapt(train_ds.map(lambda x, y: x))
    
    model = tf.keras.Sequential()

    model.add(tf.keras.Input(shape=(1,), dtype=tf.string))

    model.add(vector_layer)

    model.add(
        tf.keras.layers.Embedding(
            max_tokens + 1, 
            emb_size,
            mask_zero = True
        )
    )

    model.add(tf.keras.layers.Flatten())

    model.add(tf.keras.layers.Dense(64, activation='relu'))

    model.add(tf.keras.layers.Dropout(rate=0.2))

    model.add(tf.keras.layers.Dense(32, activation='relu'))

    model.add(tf.keras.layers.Dropout(rate=0.4))

    model.add(tf.keras.layers.Dense(64, activation='relu'))

    model.add(tf.keras.layers.Dropout(rate=0.2))

    model.add(tf.keras.layers.Dense(number_of_authors, activation='softmax'))
    
    optimizer = optimizer(learning_rate=lr)
    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metrics,
    )
    
    history = model.fit(
        train_ds.batch(batch_size),
        validation_data=valid_ds.batch(1),
        epochs=epochs,
        callbacks=[
            CSVLogger(current_path),
            es
        ]
    )
    
    prediction = model.predict(test_ds.batch(1))

    
    y_pred = prediction_to_labels(prediction)
    accuracy = accuracy_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred)
    
    return save_experiment_info(
        current_path,
        ModelName=MODEL_NAME,
        BatchSize=batch_size,
        Optimizer=type(optimizer).__name__,
        Epochs=epochs,
        EmbeddingSize=emb_size,
        Time=BLANK,
        Accuracy=accuracy,
        LR=lr,
        Hits=0,
        Miss=0,
        Key=key,
        SeqLen=output_sequence_length,
        VocabSize=max_tokens,
        TrainableEmbedding=True,
        ConfMatrix=conf_matrix,
        ModelType="NORMAL",
        TransformerName=BLANK,
        NumberOfAuthors=number_of_authors
    )
In [1280]:
def generate_model_1_experiments():
    for embedding_size in EMB_SIZES:
        for vocab_size in [10000]:
            for author in [5, 15]:
                for seq_len in [200, 400]:
                    for key in ALL_KEYS:
                        for optimizer in [ADAM]:
                            for batch_size in BATCH_SIZES:
                                for epoch in [EPOCHS]:
                                    for lr in [LR]:
                                        yield lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author
In [1281]:
len(list(generate_model_1_experiments()))
Out[1281]:
32
In [1274]:
for exp_values in generate_model_1_experiments():
    lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author = exp_values
    run_dense_model(
        max_tokens=vocab_size,
        output_sequence_length=seq_len,
        number_of_authors=author,
        emb_size=embedding_size,
        key=key,
        loss=LOSS,
        optimizer=optimizer,
        metrics=METRICS,
        batch_size=batch_size,
        epochs=epoch,
        lr=lr
    )
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 20s 21ms/step - loss: 1.1092 - accuracy: 0.5457 - val_loss: 0.6625 - val_accuracy: 0.7520 - time: 19.6096
Epoch 2/10
897/897 [==============================] - 19s 21ms/step - loss: 0.5454 - accuracy: 0.8149 - val_loss: 0.6588 - val_accuracy: 0.7741 - time: 38.7481
Epoch 3/10
897/897 [==============================] - 19s 21ms/step - loss: 0.3251 - accuracy: 0.8951 - val_loss: 0.7480 - val_accuracy: 0.7776 - time: 57.6086
Epoch 4/10
897/897 [==============================] - 19s 21ms/step - loss: 0.2037 - accuracy: 0.9366 - val_loss: 0.8819 - val_accuracy: 0.7727 - time: 76.6590
Epoch 5/10
897/897 [==============================] - 19s 21ms/step - loss: 0.1362 - accuracy: 0.9587 - val_loss: 1.0665 - val_accuracy: 0.7686 - time: 95.5141
Saving to ./experiments/1/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 20s 22ms/step - loss: 1.1327 - accuracy: 0.5528 - val_loss: 0.6745 - val_accuracy: 0.7509 - time: 20.0904
Epoch 2/10
897/897 [==============================] - 19s 21ms/step - loss: 0.5607 - accuracy: 0.8075 - val_loss: 0.6359 - val_accuracy: 0.7798 - time: 39.3553
Epoch 3/10
897/897 [==============================] - 19s 21ms/step - loss: 0.3235 - accuracy: 0.8943 - val_loss: 0.7423 - val_accuracy: 0.7864 - time: 58.5663
Epoch 4/10
897/897 [==============================] - 19s 21ms/step - loss: 0.1981 - accuracy: 0.9376 - val_loss: 0.9064 - val_accuracy: 0.7813 - time: 77.7837
Epoch 5/10
897/897 [==============================] - 19s 21ms/step - loss: 0.1383 - accuracy: 0.9579 - val_loss: 1.0002 - val_accuracy: 0.7838 - time: 96.9771
Saving to ./experiments/2/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 19s 21ms/step - loss: 1.2255 - accuracy: 0.4703 - val_loss: 0.8191 - val_accuracy: 0.6439 - time: 18.9839
Epoch 2/10
897/897 [==============================] - 18s 20ms/step - loss: 0.6994 - accuracy: 0.7103 - val_loss: 0.6556 - val_accuracy: 0.7667 - time: 37.2729
Epoch 3/10
897/897 [==============================] - 18s 21ms/step - loss: 0.4104 - accuracy: 0.8544 - val_loss: 0.6998 - val_accuracy: 0.7926 - time: 55.8114
Epoch 4/10
897/897 [==============================] - 18s 20ms/step - loss: 0.2596 - accuracy: 0.9115 - val_loss: 0.8018 - val_accuracy: 0.7928 - time: 74.0906
Epoch 5/10
897/897 [==============================] - 18s 20ms/step - loss: 0.1830 - accuracy: 0.9377 - val_loss: 0.9401 - val_accuracy: 0.7890 - time: 92.3428
Saving to ./experiments/3/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 20s 21ms/step - loss: 1.1458 - accuracy: 0.5311 - val_loss: 0.6403 - val_accuracy: 0.7641 - time: 19.8549
Epoch 2/10
897/897 [==============================] - 19s 21ms/step - loss: 0.5118 - accuracy: 0.8236 - val_loss: 0.5438 - val_accuracy: 0.8133 - time: 38.8472
Epoch 3/10
897/897 [==============================] - 19s 21ms/step - loss: 0.2995 - accuracy: 0.9016 - val_loss: 0.6463 - val_accuracy: 0.8118 - time: 58.1010
Epoch 4/10
897/897 [==============================] - 19s 21ms/step - loss: 0.1938 - accuracy: 0.9365 - val_loss: 0.8115 - val_accuracy: 0.8100 - time: 77.2383
Epoch 5/10
897/897 [==============================] - 19s 21ms/step - loss: 0.1460 - accuracy: 0.9526 - val_loss: 0.8223 - val_accuracy: 0.8096 - time: 96.3258
Saving to ./experiments/4/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 24s 26ms/step - loss: 1.2396 - accuracy: 0.4879 - val_loss: 0.7236 - val_accuracy: 0.7349 - time: 24.0312
Epoch 2/10
897/897 [==============================] - 23s 26ms/step - loss: 0.6198 - accuracy: 0.7835 - val_loss: 0.5899 - val_accuracy: 0.7900 - time: 47.3286
Epoch 3/10
897/897 [==============================] - 23s 26ms/step - loss: 0.3762 - accuracy: 0.8778 - val_loss: 0.6444 - val_accuracy: 0.7909 - time: 70.3903
Epoch 4/10
897/897 [==============================] - 23s 26ms/step - loss: 0.2353 - accuracy: 0.9266 - val_loss: 0.7783 - val_accuracy: 0.7904 - time: 93.5067
Epoch 5/10
897/897 [==============================] - 23s 26ms/step - loss: 0.1597 - accuracy: 0.9511 - val_loss: 0.8636 - val_accuracy: 0.7906 - time: 116.8653
Saving to ./experiments/5/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 24s 26ms/step - loss: 1.2563 - accuracy: 0.4734 - val_loss: 0.8248 - val_accuracy: 0.6913 - time: 23.6612
Epoch 2/10
897/897 [==============================] - 23s 26ms/step - loss: 0.6661 - accuracy: 0.7624 - val_loss: 0.6137 - val_accuracy: 0.7837 - time: 46.7763
Epoch 3/10
897/897 [==============================] - 23s 26ms/step - loss: 0.3902 - accuracy: 0.8723 - val_loss: 0.6765 - val_accuracy: 0.7818 - time: 69.9891
Epoch 4/10
897/897 [==============================] - 23s 25ms/step - loss: 0.2441 - accuracy: 0.9226 - val_loss: 0.7264 - val_accuracy: 0.7920 - time: 92.5897
Epoch 5/10
897/897 [==============================] - 22s 25ms/step - loss: 0.1609 - accuracy: 0.9496 - val_loss: 0.8897 - val_accuracy: 0.7873 - time: 114.9780
Saving to ./experiments/6/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 23s 25ms/step - loss: 1.4970 - accuracy: 0.3287 - val_loss: 1.0639 - val_accuracy: 0.5627 - time: 22.8250
Epoch 2/10
897/897 [==============================] - 22s 24ms/step - loss: 0.8271 - accuracy: 0.6617 - val_loss: 0.6806 - val_accuracy: 0.7440 - time: 44.7483
Epoch 3/10
897/897 [==============================] - 22s 25ms/step - loss: 0.4942 - accuracy: 0.8202 - val_loss: 0.6229 - val_accuracy: 0.7909 - time: 66.9495
Epoch 4/10
897/897 [==============================] - 22s 25ms/step - loss: 0.3114 - accuracy: 0.8923 - val_loss: 0.6957 - val_accuracy: 0.8049 - time: 89.0135
Epoch 5/10
897/897 [==============================] - 22s 25ms/step - loss: 0.2108 - accuracy: 0.9272 - val_loss: 0.8779 - val_accuracy: 0.7997 - time: 111.1830
Epoch 6/10
897/897 [==============================] - 22s 25ms/step - loss: 0.1557 - accuracy: 0.9475 - val_loss: 0.9966 - val_accuracy: 0.7965 - time: 133.2216
Saving to ./experiments/7/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 24s 26ms/step - loss: 1.2243 - accuracy: 0.4825 - val_loss: 0.7541 - val_accuracy: 0.7225 - time: 23.6197
Epoch 2/10
897/897 [==============================] - 23s 25ms/step - loss: 0.5908 - accuracy: 0.7896 - val_loss: 0.5499 - val_accuracy: 0.8060 - time: 46.3566
Epoch 3/10
897/897 [==============================] - 22s 25ms/step - loss: 0.3404 - accuracy: 0.8870 - val_loss: 0.5833 - val_accuracy: 0.8177 - time: 68.8512
Epoch 4/10
897/897 [==============================] - 22s 25ms/step - loss: 0.2177 - accuracy: 0.9282 - val_loss: 0.7044 - val_accuracy: 0.8138 - time: 91.2466
Epoch 5/10
897/897 [==============================] - 23s 25ms/step - loss: 0.1563 - accuracy: 0.9484 - val_loss: 0.8454 - val_accuracy: 0.8041 - time: 114.0871
Saving to ./experiments/8/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 58s 21ms/step - loss: 2.0389 - accuracy: 0.3196 - val_loss: 1.6315 - val_accuracy: 0.4372 - time: 58.1988
Epoch 2/10
2690/2690 [==============================] - 58s 22ms/step - loss: 1.5636 - accuracy: 0.4617 - val_loss: 1.4995 - val_accuracy: 0.4932 - time: 116.4055
Epoch 3/10
2690/2690 [==============================] - 58s 22ms/step - loss: 1.3251 - accuracy: 0.5466 - val_loss: 1.4742 - val_accuracy: 0.5173 - time: 174.3456
Epoch 4/10
2690/2690 [==============================] - 57s 21ms/step - loss: 1.1326 - accuracy: 0.6159 - val_loss: 1.5002 - val_accuracy: 0.5404 - time: 231.5611
Epoch 5/10
2690/2690 [==============================] - 57s 21ms/step - loss: 0.9792 - accuracy: 0.6727 - val_loss: 1.5381 - val_accuracy: 0.5526 - time: 288.9038
Epoch 6/10
2690/2690 [==============================] - 58s 22ms/step - loss: 0.8644 - accuracy: 0.7159 - val_loss: 1.6331 - val_accuracy: 0.5532 - time: 346.8530
Saving to ./experiments/9/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 58s 21ms/step - loss: 2.0078 - accuracy: 0.3345 - val_loss: 1.6180 - val_accuracy: 0.4401 - time: 57.7445
Epoch 2/10
2690/2690 [==============================] - 57s 21ms/step - loss: 1.5541 - accuracy: 0.4640 - val_loss: 1.4908 - val_accuracy: 0.4930 - time: 115.1954
Epoch 3/10
2690/2690 [==============================] - 57s 21ms/step - loss: 1.3211 - accuracy: 0.5473 - val_loss: 1.4854 - val_accuracy: 0.5198 - time: 171.8556
Epoch 4/10
2690/2690 [==============================] - 56s 21ms/step - loss: 1.1390 - accuracy: 0.6110 - val_loss: 1.5313 - val_accuracy: 0.5374 - time: 228.2602
Epoch 5/10
2690/2690 [==============================] - 57s 21ms/step - loss: 0.9905 - accuracy: 0.6658 - val_loss: 1.5988 - val_accuracy: 0.5395 - time: 284.9560
Epoch 6/10
2690/2690 [==============================] - 55s 20ms/step - loss: 0.8797 - accuracy: 0.7073 - val_loss: 1.6603 - val_accuracy: 0.5481 - time: 339.6589
Saving to ./experiments/10/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 55s 20ms/step - loss: 2.0945 - accuracy: 0.3059 - val_loss: 1.6696 - val_accuracy: 0.4403 - time: 55.1934
Epoch 2/10
2690/2690 [==============================] - 55s 20ms/step - loss: 1.5408 - accuracy: 0.4838 - val_loss: 1.3899 - val_accuracy: 0.5398 - time: 109.7722
Epoch 3/10
2690/2690 [==============================] - 55s 20ms/step - loss: 1.2604 - accuracy: 0.5822 - val_loss: 1.3462 - val_accuracy: 0.5763 - time: 164.6123
Epoch 4/10
2690/2690 [==============================] - 53s 20ms/step - loss: 1.0757 - accuracy: 0.6473 - val_loss: 1.3511 - val_accuracy: 0.5968 - time: 217.2784
Epoch 5/10
2690/2690 [==============================] - 55s 20ms/step - loss: 0.9294 - accuracy: 0.7023 - val_loss: 1.3809 - val_accuracy: 0.6075 - time: 271.9472
Epoch 6/10
2690/2690 [==============================] - 53s 20ms/step - loss: 0.8217 - accuracy: 0.7410 - val_loss: 1.4647 - val_accuracy: 0.6103 - time: 325.1410
Saving to ./experiments/11/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 58s 21ms/step - loss: 2.0427 - accuracy: 0.3319 - val_loss: 1.6068 - val_accuracy: 0.4555 - time: 57.6614
Epoch 2/10
2690/2690 [==============================] - 57s 21ms/step - loss: 1.4889 - accuracy: 0.4952 - val_loss: 1.3843 - val_accuracy: 0.5312 - time: 114.4740
Epoch 3/10
2690/2690 [==============================] - 57s 21ms/step - loss: 1.2211 - accuracy: 0.5902 - val_loss: 1.3423 - val_accuracy: 0.5706 - time: 171.9404
Epoch 4/10
2690/2690 [==============================] - 56s 21ms/step - loss: 1.0366 - accuracy: 0.6577 - val_loss: 1.3896 - val_accuracy: 0.5866 - time: 228.2280
Epoch 5/10
2690/2690 [==============================] - 56s 21ms/step - loss: 0.8927 - accuracy: 0.7113 - val_loss: 1.4266 - val_accuracy: 0.5997 - time: 284.4357
Epoch 6/10
2690/2690 [==============================] - 56s 21ms/step - loss: 0.7831 - accuracy: 0.7487 - val_loss: 1.4492 - val_accuracy: 0.6127 - time: 340.2862
Saving to ./experiments/12/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 68s 25ms/step - loss: 2.0768 - accuracy: 0.3215 - val_loss: 1.6491 - val_accuracy: 0.4290 - time: 68.1557
Epoch 2/10
2690/2690 [==============================] - 65s 24ms/step - loss: 1.5745 - accuracy: 0.4561 - val_loss: 1.4989 - val_accuracy: 0.4890 - time: 133.4462
Epoch 3/10
2690/2690 [==============================] - 67s 25ms/step - loss: 1.3356 - accuracy: 0.5427 - val_loss: 1.4696 - val_accuracy: 0.5233 - time: 200.5947
Epoch 4/10
2690/2690 [==============================] - 67s 25ms/step - loss: 1.1526 - accuracy: 0.6099 - val_loss: 1.4971 - val_accuracy: 0.5414 - time: 268.0899
Epoch 5/10
2690/2690 [==============================] - 68s 25ms/step - loss: 0.9991 - accuracy: 0.6648 - val_loss: 1.5784 - val_accuracy: 0.5458 - time: 336.0488
Epoch 6/10
2690/2690 [==============================] - 67s 25ms/step - loss: 0.8811 - accuracy: 0.7091 - val_loss: 1.6463 - val_accuracy: 0.5537 - time: 403.1251
Saving to ./experiments/13/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 69s 26ms/step - loss: 2.1037 - accuracy: 0.3074 - val_loss: 1.6977 - val_accuracy: 0.4146 - time: 69.4531
Epoch 2/10
2690/2690 [==============================] - 67s 25ms/step - loss: 1.6014 - accuracy: 0.4520 - val_loss: 1.4982 - val_accuracy: 0.4971 - time: 136.9077
Epoch 3/10
2690/2690 [==============================] - 68s 25ms/step - loss: 1.3412 - accuracy: 0.5446 - val_loss: 1.4765 - val_accuracy: 0.5253 - time: 204.5841
Epoch 4/10
2690/2690 [==============================] - 68s 25ms/step - loss: 1.1489 - accuracy: 0.6122 - val_loss: 1.5000 - val_accuracy: 0.5382 - time: 272.7900
Epoch 5/10
2690/2690 [==============================] - 68s 25ms/step - loss: 0.9979 - accuracy: 0.6655 - val_loss: 1.5711 - val_accuracy: 0.5493 - time: 340.7916
Epoch 6/10
2690/2690 [==============================] - 69s 25ms/step - loss: 0.8836 - accuracy: 0.7058 - val_loss: 1.6070 - val_accuracy: 0.5548 - time: 409.3469
Saving to ./experiments/14/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 65s 24ms/step - loss: 2.1489 - accuracy: 0.2800 - val_loss: 1.8111 - val_accuracy: 0.3535 - time: 64.8290
Epoch 2/10
2690/2690 [==============================] - 66s 25ms/step - loss: 1.7113 - accuracy: 0.3945 - val_loss: 1.5883 - val_accuracy: 0.4578 - time: 130.9426
Epoch 3/10
2690/2690 [==============================] - 65s 24ms/step - loss: 1.4490 - accuracy: 0.4916 - val_loss: 1.4911 - val_accuracy: 0.4995 - time: 195.9402
Epoch 4/10
2690/2690 [==============================] - 65s 24ms/step - loss: 1.2560 - accuracy: 0.5620 - val_loss: 1.4724 - val_accuracy: 0.5236 - time: 260.5758
Epoch 5/10
2690/2690 [==============================] - 65s 24ms/step - loss: 1.1043 - accuracy: 0.6202 - val_loss: 1.4690 - val_accuracy: 0.5491 - time: 326.0909
Epoch 6/10
2690/2690 [==============================] - 66s 24ms/step - loss: 0.9787 - accuracy: 0.6690 - val_loss: 1.4998 - val_accuracy: 0.5604 - time: 391.7768
Epoch 7/10
2690/2690 [==============================] - 66s 24ms/step - loss: 0.8763 - accuracy: 0.7058 - val_loss: 1.5598 - val_accuracy: 0.5729 - time: 457.4595
Epoch 8/10
2690/2690 [==============================] - 65s 24ms/step - loss: 0.8010 - accuracy: 0.7346 - val_loss: 1.5760 - val_accuracy: 0.5733 - time: 522.4126
Saving to ./experiments/15/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 69s 25ms/step - loss: 2.1053 - accuracy: 0.3026 - val_loss: 1.6754 - val_accuracy: 0.4143 - time: 68.5011
Epoch 2/10
2690/2690 [==============================] - 68s 25ms/step - loss: 1.5558 - accuracy: 0.4586 - val_loss: 1.4043 - val_accuracy: 0.5211 - time: 136.1559
Epoch 3/10
2690/2690 [==============================] - 68s 25ms/step - loss: 1.2706 - accuracy: 0.5615 - val_loss: 1.3419 - val_accuracy: 0.5576 - time: 204.3615
Epoch 4/10
2690/2690 [==============================] - 68s 25ms/step - loss: 1.0901 - accuracy: 0.6244 - val_loss: 1.3487 - val_accuracy: 0.5815 - time: 272.1085
Epoch 5/10
2690/2690 [==============================] - 68s 25ms/step - loss: 0.9526 - accuracy: 0.6788 - val_loss: 1.3749 - val_accuracy: 0.5987 - time: 339.7991
Epoch 6/10
2690/2690 [==============================] - 69s 26ms/step - loss: 0.8404 - accuracy: 0.7207 - val_loss: 1.4078 - val_accuracy: 0.6082 - time: 408.7929
Saving to ./experiments/16/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 43s 47ms/step - loss: 1.2061 - accuracy: 0.5213 - val_loss: 0.7380 - val_accuracy: 0.7222 - time: 42.8878
Epoch 2/10
897/897 [==============================] - 41s 46ms/step - loss: 0.5877 - accuracy: 0.7960 - val_loss: 0.6602 - val_accuracy: 0.7658 - time: 84.3094
Epoch 3/10
897/897 [==============================] - 41s 46ms/step - loss: 0.3010 - accuracy: 0.9012 - val_loss: 0.7774 - val_accuracy: 0.7722 - time: 125.6398
Epoch 4/10
897/897 [==============================] - 41s 46ms/step - loss: 0.1749 - accuracy: 0.9461 - val_loss: 0.9172 - val_accuracy: 0.7683 - time: 166.7339
Epoch 5/10
897/897 [==============================] - 41s 46ms/step - loss: 0.1109 - accuracy: 0.9665 - val_loss: 1.0060 - val_accuracy: 0.7639 - time: 207.7955
Saving to ./experiments/17/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 43s 47ms/step - loss: 1.1219 - accuracy: 0.5644 - val_loss: 0.6812 - val_accuracy: 0.7573 - time: 42.9350
Epoch 2/10
897/897 [==============================] - 42s 46ms/step - loss: 0.5310 - accuracy: 0.8199 - val_loss: 0.6584 - val_accuracy: 0.7799 - time: 84.6618
Epoch 3/10
897/897 [==============================] - 42s 46ms/step - loss: 0.2790 - accuracy: 0.9114 - val_loss: 0.7617 - val_accuracy: 0.7885 - time: 126.4006
Epoch 4/10
897/897 [==============================] - 41s 46ms/step - loss: 0.1603 - accuracy: 0.9515 - val_loss: 0.8541 - val_accuracy: 0.7856 - time: 167.3357
Epoch 5/10
897/897 [==============================] - 42s 46ms/step - loss: 0.1132 - accuracy: 0.9660 - val_loss: 0.9667 - val_accuracy: 0.7854 - time: 208.9221
Saving to ./experiments/18/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 41s 45ms/step - loss: 1.2983 - accuracy: 0.4404 - val_loss: 0.8450 - val_accuracy: 0.6599 - time: 41.1660
Epoch 2/10
897/897 [==============================] - 40s 45ms/step - loss: 0.6719 - accuracy: 0.7503 - val_loss: 0.6063 - val_accuracy: 0.7831 - time: 81.3468
Epoch 3/10
897/897 [==============================] - 39s 43ms/step - loss: 0.3585 - accuracy: 0.8751 - val_loss: 0.6537 - val_accuracy: 0.8036 - time: 120.3311
Epoch 4/10
897/897 [==============================] - 40s 45ms/step - loss: 0.2178 - accuracy: 0.9277 - val_loss: 0.7792 - val_accuracy: 0.8030 - time: 160.6190
Epoch 5/10
897/897 [==============================] - 41s 46ms/step - loss: 0.1476 - accuracy: 0.9496 - val_loss: 0.8659 - val_accuracy: 0.8031 - time: 201.4358
Saving to ./experiments/19/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 42s 46ms/step - loss: 1.1131 - accuracy: 0.5664 - val_loss: 0.6342 - val_accuracy: 0.7694 - time: 42.3203
Epoch 2/10
897/897 [==============================] - 41s 45ms/step - loss: 0.4970 - accuracy: 0.8301 - val_loss: 0.5627 - val_accuracy: 0.8014 - time: 83.0308
Epoch 3/10
897/897 [==============================] - 41s 45ms/step - loss: 0.2636 - accuracy: 0.9124 - val_loss: 0.6593 - val_accuracy: 0.8094 - time: 123.8527
Epoch 4/10
897/897 [==============================] - 40s 45ms/step - loss: 0.1684 - accuracy: 0.9451 - val_loss: 0.7868 - val_accuracy: 0.8050 - time: 164.2680
Epoch 5/10
897/897 [==============================] - 41s 46ms/step - loss: 0.1226 - accuracy: 0.9602 - val_loss: 0.8197 - val_accuracy: 0.7995 - time: 205.1059
Saving to ./experiments/20/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 76s 84ms/step - loss: 1.4314 - accuracy: 0.3718 - val_loss: 1.1860 - val_accuracy: 0.4921 - time: 75.8315
Epoch 2/10
897/897 [==============================] - 73s 81ms/step - loss: 0.9765 - accuracy: 0.5907 - val_loss: 0.7914 - val_accuracy: 0.7023 - time: 148.4724
Epoch 3/10
897/897 [==============================] - 73s 82ms/step - loss: 0.5856 - accuracy: 0.7853 - val_loss: 0.6904 - val_accuracy: 0.7548 - time: 221.8757
Epoch 4/10
897/897 [==============================] - 71s 80ms/step - loss: 0.3427 - accuracy: 0.8844 - val_loss: 0.7478 - val_accuracy: 0.7689 - time: 293.3946
Epoch 5/10
897/897 [==============================] - 69s 77ms/step - loss: 0.2151 - accuracy: 0.9300 - val_loss: 0.8492 - val_accuracy: 0.7686 - time: 362.3219
Epoch 6/10
897/897 [==============================] - 71s 79ms/step - loss: 0.1472 - accuracy: 0.9546 - val_loss: 1.0010 - val_accuracy: 0.7689 - time: 433.5334
Saving to ./experiments/21/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 72s 80ms/step - loss: 1.2526 - accuracy: 0.4923 - val_loss: 0.8532 - val_accuracy: 0.6709 - time: 72.2412
Epoch 2/10
897/897 [==============================] - 71s 79ms/step - loss: 0.7181 - accuracy: 0.7276 - val_loss: 0.7140 - val_accuracy: 0.7448 - time: 143.0743
Epoch 3/10
897/897 [==============================] - 72s 80ms/step - loss: 0.4157 - accuracy: 0.8593 - val_loss: 0.6968 - val_accuracy: 0.7798 - time: 214.9911
Epoch 4/10
897/897 [==============================] - 71s 79ms/step - loss: 0.2424 - accuracy: 0.9235 - val_loss: 0.7680 - val_accuracy: 0.7853 - time: 286.2484
Epoch 5/10
897/897 [==============================] - 72s 80ms/step - loss: 0.1611 - accuracy: 0.9519 - val_loss: 0.8226 - val_accuracy: 0.7791 - time: 358.0559
Epoch 6/10
897/897 [==============================] - 71s 79ms/step - loss: 0.1120 - accuracy: 0.9671 - val_loss: 0.9719 - val_accuracy: 0.7857 - time: 428.7859
Saving to ./experiments/22/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 72s 79ms/step - loss: 1.6130 - accuracy: 0.2575 - val_loss: 1.6093 - val_accuracy: 0.1923 - time: 71.8384
Epoch 2/10
897/897 [==============================] - 71s 79ms/step - loss: 1.6099 - accuracy: 0.2005 - val_loss: 1.6093 - val_accuracy: 0.1992 - time: 143.0636
Epoch 3/10
897/897 [==============================] - 72s 81ms/step - loss: 1.6103 - accuracy: 0.2002 - val_loss: 1.6096 - val_accuracy: 0.1920 - time: 215.4408
Epoch 4/10
897/897 [==============================] - 72s 81ms/step - loss: 1.6100 - accuracy: 0.2012 - val_loss: 1.6096 - val_accuracy: 0.1987 - time: 287.8035
Epoch 5/10
897/897 [==============================] - 73s 82ms/step - loss: 1.6105 - accuracy: 0.1998 - val_loss: 1.6096 - val_accuracy: 0.1922 - time: 360.9053
Saving to ./experiments/23/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 73s 81ms/step - loss: 1.6123 - accuracy: 0.1997 - val_loss: 1.6084 - val_accuracy: 0.1922 - time: 72.8000
Epoch 2/10
897/897 [==============================] - 71s 79ms/step - loss: 1.4137 - accuracy: 0.3396 - val_loss: 1.1041 - val_accuracy: 0.5329 - time: 144.0263
Epoch 3/10
897/897 [==============================] - 71s 79ms/step - loss: 0.8928 - accuracy: 0.6406 - val_loss: 0.7459 - val_accuracy: 0.7194 - time: 215.0285
Epoch 4/10
897/897 [==============================] - 72s 80ms/step - loss: 0.5330 - accuracy: 0.8096 - val_loss: 0.6141 - val_accuracy: 0.7989 - time: 286.6653
Epoch 5/10
897/897 [==============================] - 73s 82ms/step - loss: 0.3148 - accuracy: 0.8956 - val_loss: 0.6382 - val_accuracy: 0.8080 - time: 360.2254
Epoch 6/10
897/897 [==============================] - 73s 81ms/step - loss: 0.2174 - accuracy: 0.9292 - val_loss: 0.6781 - val_accuracy: 0.8146 - time: 432.7831
Epoch 7/10
897/897 [==============================] - 71s 79ms/step - loss: 0.1608 - accuracy: 0.9483 - val_loss: 0.7662 - val_accuracy: 0.8127 - time: 503.9513
Saving to ./experiments/24/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 122s 45ms/step - loss: 2.0102 - accuracy: 0.3290 - val_loss: 1.6453 - val_accuracy: 0.4280 - time: 122.0549
Epoch 2/10
2690/2690 [==============================] - 120s 45ms/step - loss: 1.5291 - accuracy: 0.4641 - val_loss: 1.5359 - val_accuracy: 0.4795 - time: 241.9263
Epoch 3/10
2690/2690 [==============================] - 122s 45ms/step - loss: 1.2336 - accuracy: 0.5628 - val_loss: 1.5504 - val_accuracy: 0.5069 - time: 364.0079
Epoch 4/10
2690/2690 [==============================] - 122s 45ms/step - loss: 0.9965 - accuracy: 0.6463 - val_loss: 1.6414 - val_accuracy: 0.5103 - time: 486.2525
Epoch 5/10
2690/2690 [==============================] - 122s 45ms/step - loss: 0.8289 - accuracy: 0.7106 - val_loss: 1.6834 - val_accuracy: 0.5260 - time: 607.9542
Saving to ./experiments/25/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 123s 45ms/step - loss: 2.0808 - accuracy: 0.3102 - val_loss: 1.6946 - val_accuracy: 0.4060 - time: 122.9349
Epoch 2/10
2690/2690 [==============================] - 121s 45ms/step - loss: 1.5866 - accuracy: 0.4425 - val_loss: 1.5601 - val_accuracy: 0.4671 - time: 244.0780
Epoch 3/10
2690/2690 [==============================] - 122s 45ms/step - loss: 1.3025 - accuracy: 0.5354 - val_loss: 1.5680 - val_accuracy: 0.4814 - time: 366.3781
Epoch 4/10
2690/2690 [==============================] - 120s 45ms/step - loss: 1.0726 - accuracy: 0.6133 - val_loss: 1.6285 - val_accuracy: 0.4982 - time: 486.3974
Epoch 5/10
2690/2690 [==============================] - 122s 45ms/step - loss: 0.9043 - accuracy: 0.6749 - val_loss: 1.6716 - val_accuracy: 0.5089 - time: 608.0671
Saving to ./experiments/26/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 125s 46ms/step - loss: 2.2296 - accuracy: 0.2580 - val_loss: 1.8096 - val_accuracy: 0.3592 - time: 124.4856
Epoch 2/10
2690/2690 [==============================] - 123s 46ms/step - loss: 1.6584 - accuracy: 0.4205 - val_loss: 1.5306 - val_accuracy: 0.4859 - time: 248.0259
Epoch 3/10
2690/2690 [==============================] - 121s 45ms/step - loss: 1.3361 - accuracy: 0.5402 - val_loss: 1.4402 - val_accuracy: 0.5341 - time: 368.8101
Epoch 4/10
2690/2690 [==============================] - 122s 46ms/step - loss: 1.0787 - accuracy: 0.6410 - val_loss: 1.4426 - val_accuracy: 0.5686 - time: 491.3656
Epoch 5/10
2690/2690 [==============================] - 117s 44ms/step - loss: 0.8880 - accuracy: 0.7134 - val_loss: 1.4858 - val_accuracy: 0.5840 - time: 608.8162
Epoch 6/10
2690/2690 [==============================] - 120s 45ms/step - loss: 0.7560 - accuracy: 0.7655 - val_loss: 1.5254 - val_accuracy: 0.5940 - time: 729.1473
Saving to ./experiments/27/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 123s 45ms/step - loss: 1.9839 - accuracy: 0.3485 - val_loss: 1.5318 - val_accuracy: 0.4723 - time: 122.9904
Epoch 2/10
2690/2690 [==============================] - 122s 46ms/step - loss: 1.4141 - accuracy: 0.5101 - val_loss: 1.3959 - val_accuracy: 0.5345 - time: 245.4391
Epoch 3/10
2690/2690 [==============================] - 122s 45ms/step - loss: 1.1213 - accuracy: 0.6072 - val_loss: 1.3805 - val_accuracy: 0.5586 - time: 367.5131
Epoch 4/10
2690/2690 [==============================] - 121s 45ms/step - loss: 0.9080 - accuracy: 0.6855 - val_loss: 1.4342 - val_accuracy: 0.5822 - time: 488.1874
Epoch 5/10
2690/2690 [==============================] - 124s 46ms/step - loss: 0.7550 - accuracy: 0.7454 - val_loss: 1.4890 - val_accuracy: 0.5930 - time: 611.9943
Epoch 6/10
2690/2690 [==============================] - 122s 45ms/step - loss: 0.6421 - accuracy: 0.7907 - val_loss: 1.5445 - val_accuracy: 0.6021 - time: 733.5118
Saving to ./experiments/28/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 220s 81ms/step - loss: 2.2987 - accuracy: 0.2511 - val_loss: 2.0420 - val_accuracy: 0.3019 - time: 219.6768
Epoch 2/10
2690/2690 [==============================] - 214s 80ms/step - loss: 1.9066 - accuracy: 0.3432 - val_loss: 1.7355 - val_accuracy: 0.4036 - time: 433.9882
Epoch 3/10
2690/2690 [==============================] - 217s 81ms/step - loss: 1.5962 - accuracy: 0.4386 - val_loss: 1.6540 - val_accuracy: 0.4393 - time: 650.6270
Epoch 4/10
2690/2690 [==============================] - 221s 82ms/step - loss: 1.3669 - accuracy: 0.5185 - val_loss: 1.6590 - val_accuracy: 0.4497 - time: 871.8096
Epoch 5/10
2690/2690 [==============================] - 206s 77ms/step - loss: 1.1789 - accuracy: 0.5871 - val_loss: 1.6885 - val_accuracy: 0.4685 - time: 1078.1029
Epoch 6/10
2690/2690 [==============================] - 208s 77ms/step - loss: 1.0267 - accuracy: 0.6456 - val_loss: 1.7340 - val_accuracy: 0.4851 - time: 1286.4270
Saving to ./experiments/29/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 212s 79ms/step - loss: 2.2158 - accuracy: 0.2619 - val_loss: 1.9355 - val_accuracy: 0.3273 - time: 212.3794
Epoch 2/10
2690/2690 [==============================] - 205s 76ms/step - loss: 1.8288 - accuracy: 0.3532 - val_loss: 1.7168 - val_accuracy: 0.3944 - time: 417.8752
Epoch 3/10
2690/2690 [==============================] - 210s 78ms/step - loss: 1.5736 - accuracy: 0.4344 - val_loss: 1.6661 - val_accuracy: 0.4275 - time: 627.8725
Epoch 4/10
2690/2690 [==============================] - 212s 79ms/step - loss: 1.3595 - accuracy: 0.5067 - val_loss: 1.6519 - val_accuracy: 0.4645 - time: 839.9593
Epoch 5/10
2690/2690 [==============================] - 207s 77ms/step - loss: 1.1728 - accuracy: 0.5773 - val_loss: 1.6700 - val_accuracy: 0.4849 - time: 1047.2926
Epoch 6/10
2690/2690 [==============================] - 214s 80ms/step - loss: 1.0239 - accuracy: 0.6329 - val_loss: 1.7395 - val_accuracy: 0.4991 - time: 1261.5815
Epoch 7/10
2690/2690 [==============================] - 214s 80ms/step - loss: 0.9046 - accuracy: 0.6801 - val_loss: 1.8413 - val_accuracy: 0.4966 - time: 1475.8997
Saving to ./experiments/30/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 214s 79ms/step - loss: 2.7099 - accuracy: 0.1087 - val_loss: 2.7085 - val_accuracy: 0.0659 - time: 213.5496
Epoch 2/10
2690/2690 [==============================] - 215s 80ms/step - loss: 2.7083 - accuracy: 0.0674 - val_loss: 2.7085 - val_accuracy: 0.0658 - time: 428.6687
Epoch 3/10
2690/2690 [==============================] - 218s 81ms/step - loss: 2.7085 - accuracy: 0.0663 - val_loss: 2.7085 - val_accuracy: 0.0649 - time: 646.2720
Epoch 4/10
2690/2690 [==============================] - 223s 83ms/step - loss: 2.7084 - accuracy: 0.0661 - val_loss: 2.7084 - val_accuracy: 0.0649 - time: 869.2077
Epoch 5/10
2690/2690 [==============================] - 218s 81ms/step - loss: 2.7082 - accuracy: 0.0661 - val_loss: 2.7081 - val_accuracy: 0.0650 - time: 1086.8918
Epoch 6/10
2690/2690 [==============================] - 218s 81ms/step - loss: 2.4676 - accuracy: 0.1570 - val_loss: 2.2634 - val_accuracy: 0.2210 - time: 1304.6082
Epoch 7/10
2690/2690 [==============================] - 222s 82ms/step - loss: 2.1338 - accuracy: 0.2583 - val_loss: 2.0906 - val_accuracy: 0.2700 - time: 1526.2462
Epoch 8/10
2690/2690 [==============================] - 213s 79ms/step - loss: 1.9281 - accuracy: 0.3258 - val_loss: 1.9309 - val_accuracy: 0.3318 - time: 1739.2393
Epoch 9/10
2690/2690 [==============================] - 225s 84ms/step - loss: 1.6518 - accuracy: 0.4234 - val_loss: 1.7446 - val_accuracy: 0.4059 - time: 1964.0190
Epoch 10/10
2690/2690 [==============================] - 219s 82ms/step - loss: 1.4152 - accuracy: 0.5107 - val_loss: 1.7144 - val_accuracy: 0.4340 - time: 2183.3675
Saving to ./experiments/31/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 229s 85ms/step - loss: 2.1107 - accuracy: 0.2927 - val_loss: 1.7549 - val_accuracy: 0.3950 - time: 229.2097
Epoch 2/10
2690/2690 [==============================] - 212s 79ms/step - loss: 1.6085 - accuracy: 0.4436 - val_loss: 1.4837 - val_accuracy: 0.5008 - time: 441.2030
Epoch 3/10
2690/2690 [==============================] - 215s 80ms/step - loss: 1.2858 - accuracy: 0.5538 - val_loss: 1.4181 - val_accuracy: 0.5407 - time: 655.9851
Epoch 4/10
2690/2690 [==============================] - 217s 81ms/step - loss: 1.0564 - accuracy: 0.6399 - val_loss: 1.4044 - val_accuracy: 0.5668 - time: 873.4003
Epoch 5/10
2690/2690 [==============================] - 219s 81ms/step - loss: 0.8795 - accuracy: 0.7080 - val_loss: 1.4369 - val_accuracy: 0.5816 - time: 1092.1185
Epoch 6/10
2690/2690 [==============================] - 217s 81ms/step - loss: 0.7448 - accuracy: 0.7617 - val_loss: 1.4501 - val_accuracy: 0.5995 - time: 1309.1906
Epoch 7/10
2690/2690 [==============================] - 221s 82ms/step - loss: 0.6385 - accuracy: 0.8015 - val_loss: 1.4799 - val_accuracy: 0.6026 - time: 1530.5399
Saving to ./experiments/32/sum.csv

Komplexnější model¶

In [1294]:
def run_rnn_model(
    max_tokens,
    output_sequence_length,
    number_of_authors,
    emb_size,
    key,
    loss,
    optimizer,
    metrics,
    batch_size,
    epochs,
    lr
):
    MODEL_NAME = "Bidirectional GRU"
    current_path = setup_directory()
    
    current_data = data[str(number_of_authors)][key]
    loader = get_load_path_53 if number_of_authors == 5 else get_load_path_153
    encoder = create_encoder_from_path(loader(AUTHORS_FILENAME))
    
    X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(current_data)
    
    y_test = encoder.transform(y_test)
    y_train = encoder.transform(y_train)
    y_valid = encoder.transform(y_valid)
    
    train_ds = create_dataset_from_Xy(X_train, y_train)
    test_ds = create_dataset_from_Xy(X_test, y_test)
    valid_ds = create_dataset_from_Xy(X_valid, y_valid)
    
    
    vector_layer = tf.keras.layers.TextVectorization(
        max_tokens=max_tokens,
        output_mode='int',
        standardize=None,
        output_sequence_length=output_sequence_length,
        split='whitespace'
    )

    vector_layer.adapt(train_ds.map(lambda x, y: x))
    
    model = tf.keras.Sequential()

    model.add(tf.keras.Input(shape=(1,), dtype=tf.string))

    model.add(vector_layer)

    model.add(
        tf.keras.layers.Embedding(
            max_tokens + 1, 
            emb_size,
            mask_zero = True
        )
    )


    model.add(tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(64, activation='relu', return_sequences=True, dropout=0.2, recurrent_dropout=0.2)
    ))

    model.add(
        tf.keras.layers.GRU(64, activation='relu', return_sequences=False)
    )

    model.add(tf.keras.layers.Dense(32, activation='relu'))

    model.add(tf.keras.layers.Dropout(rate=0.2))

    model.add(tf.keras.layers.Dense(32, activation='relu'))

    model.add(tf.keras.layers.Dropout(rate=0.3))

    model.add(tf.keras.layers.Dense(64, activation='relu'))

    model.add(tf.keras.layers.Dense(number_of_authors, activation='softmax'))

    optimizer = optimizer(learning_rate=lr)
    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metrics,
    )

    history = model.fit(
        train_ds.batch(batch_size),
        validation_data=valid_ds.batch(1),
        epochs=epochs,
        callbacks=[
            CSVLogger(current_path),
            es
        ]
    )
    
    prediction = model.predict(test_ds.batch(1))

    
    y_pred = prediction_to_labels(prediction)
    accuracy = accuracy_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred)
    
    return save_experiment_info(
        current_path,
        ModelName=MODEL_NAME,
        BatchSize=batch_size,
        Optimizer=type(optimizer).__name__,
        Epochs=epochs,
        EmbeddingSize=emb_size,
        Time=BLANK,
        Accuracy=accuracy,
        LR=lr,
        Hits=0,
        Miss=0,
        Key=key,
        SeqLen=output_sequence_length,
        VocabSize=max_tokens,
        TrainableEmbedding=True,
        ConfMatrix=conf_matrix,
        ModelType="NORMAL",
        TransformerName=BLANK,
        NumberOfAuthors=number_of_authors
    )
In [1295]:
def generate_model_rnn_experiments():
    for embedding_size in EMB_SIZES:
        for vocab_size in [10000]:
            for author in [5, 15]:
                for seq_len in [300]:
                    for key in ALL_KEYS:
                        for optimizer in [ADAM]:
                            for batch_size in BATCH_SIZES:
                                for epoch in [EPOCHS]:
                                    for lr in [LR]:
                                        yield lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author
In [1296]:
len(list(generate_model_rnn_experiments()))
Out[1296]:
16
In [ ]:
for exp_values in generate_model_rnn_experiments():
    lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author = exp_values
    run_rnn_model(
        max_tokens=vocab_size,
        output_sequence_length=seq_len,
        number_of_authors=author,
        emb_size=embedding_size,
        key=key,
        loss=LOSS,
        optimizer=optimizer,
        metrics=METRICS,
        batch_size=batch_size,
        epochs=epoch,
        lr=lr
    )
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 1093s 1s/step - loss: 6687085056.0000 - accuracy: 0.3819 - val_loss: nan - val_accuracy: 0.2941 - time: 1092.7117
Epoch 2/10
897/897 [==============================] - 1146s 1s/step - loss: 1187478372352.0000 - accuracy: 0.4749 - val_loss: 1.0665 - val_accuracy: 0.5434 - time: 2239.0744
Epoch 3/10
897/897 [==============================] - 1082s 1s/step - loss: 1.3208 - accuracy: 0.5978 - val_loss: 0.9181 - val_accuracy: 0.6209 - time: 3320.7732
Epoch 4/10
897/897 [==============================] - 1050s 1s/step - loss: 5227681792.0000 - accuracy: 0.6492 - val_loss: 0.8942 - val_accuracy: 0.6402 - time: 4371.0521
Epoch 5/10
897/897 [==============================] - 1045s 1s/step - loss: 0.8841 - accuracy: 0.6834 - val_loss: 0.8578 - val_accuracy: 0.6551 - time: 5416.5511
Epoch 6/10
897/897 [==============================] - 1078s 1s/step - loss: 0.7737 - accuracy: 0.6969 - val_loss: 0.8511 - val_accuracy: 0.6684 - time: 6494.4732
Epoch 7/10
897/897 [==============================] - 1319s 1s/step - loss: 1.5510 - accuracy: 0.7337 - val_loss: 0.8364 - val_accuracy: 0.6836 - time: 7813.6105
Epoch 8/10
897/897 [==============================] - 1404s 2s/step - loss: 0.6452 - accuracy: 0.7529 - val_loss: 0.8382 - val_accuracy: 0.6761 - time: 9240.8331
Epoch 9/10
897/897 [==============================] - 1337s 1s/step - loss: 0.6104 - accuracy: 0.7729 - val_loss: 0.8286 - val_accuracy: 0.6885 - time: 10577.5128
Epoch 10/10
897/897 [==============================] - 1290s 1s/step - loss: 0.5759 - accuracy: 0.7884 - val_loss: 0.8020 - val_accuracy: 0.7101 - time: 11867.2054
Saving to ./experiments/32/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 1369s 2s/step - loss: 2406436.7500 - accuracy: 0.3964 - val_loss: 1.1757 - val_accuracy: 0.4856 - time: 1368.7052
Epoch 2/10
897/897 [==============================] - 1432s 2s/step - loss: 25465.0742 - accuracy: 0.5265 - val_loss: 1.1074 - val_accuracy: 0.5081 - time: 2801.0649
Epoch 3/10
897/897 [==============================] - 1506s 2s/step - loss: nan - accuracy: 0.3837 - val_loss: nan - val_accuracy: 0.1922 - time: 4307.1053
Epoch 4/10
897/897 [==============================] - 1382s 2s/step - loss: nan - accuracy: 0.2015 - val_loss: nan - val_accuracy: 0.1922 - time: 5688.8658
Epoch 5/10
897/897 [==============================] - 1845s 2s/step - loss: nan - accuracy: 0.2015 - val_loss: nan - val_accuracy: 0.1922 - time: 7534.0251
Saving to ./experiments/33/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 1795s 2s/step - loss: 1.3198 - accuracy: 0.3473 - val_loss: 1.1397 - val_accuracy: 0.4634 - time: 1794.7318
Epoch 2/10
897/897 [==============================] - 1556s 2s/step - loss: 183.8856 - accuracy: 0.5911 - val_loss: 0.7558 - val_accuracy: 0.7106 - time: 3350.3680
Epoch 3/10
897/897 [==============================] - 1191s 1s/step - loss: 0.6741 - accuracy: 0.7487 - val_loss: 0.7135 - val_accuracy: 0.7355 - time: 4541.5603
Epoch 4/10
897/897 [==============================] - 1185s 1s/step - loss: 0.5505 - accuracy: 0.7981 - val_loss: 202.5703 - val_accuracy: 0.7536 - time: 5726.7569
Epoch 5/10
897/897 [==============================] - 1184s 1s/step - loss: nan - accuracy: 0.2459 - val_loss: nan - val_accuracy: 0.1922 - time: 6911.1217
Epoch 6/10
897/897 [==============================] - 1180s 1s/step - loss: nan - accuracy: 0.2015 - val_loss: nan - val_accuracy: 0.1922 - time: 8091.5012
Saving to ./experiments/34/sum.csv
Train (57375,)
Valid (6375,)
Test (11250,)
Epoch 1/10
897/897 [==============================] - 1202s 1s/step - loss: 278055.2812 - accuracy: 0.3279 - val_loss: 1.2133 - val_accuracy: 0.4411 - time: 1202.2194
Epoch 2/10
897/897 [==============================] - 1192s 1s/step - loss: 1.4219 - accuracy: 0.4619 - val_loss: 1.1394 - val_accuracy: 0.4924 - time: 2394.1133
Epoch 3/10
897/897 [==============================] - 1187s 1s/step - loss: 31504.0977 - accuracy: 0.5164 - val_loss: 1.1058 - val_accuracy: 0.5269 - time: 3581.0423
Epoch 4/10
897/897 [==============================] - 1184s 1s/step - loss: 1.0101 - accuracy: 0.5739 - val_loss: 1.0839 - val_accuracy: 0.5523 - time: 4764.7697
Epoch 5/10
897/897 [==============================] - 1191s 1s/step - loss: 0.9372 - accuracy: 0.6264 - val_loss: 0.8990 - val_accuracy: 0.6436 - time: 5955.9341
Epoch 6/10
897/897 [==============================] - 1185s 1s/step - loss: 0.7581 - accuracy: 0.6962 - val_loss: 0.8012 - val_accuracy: 0.6907 - time: 7140.5361
Epoch 7/10
897/897 [==============================] - 1188s 1s/step - loss: 0.6762 - accuracy: 0.7385 - val_loss: 0.7789 - val_accuracy: 0.7209 - time: 8328.5538
Epoch 8/10
897/897 [==============================] - 1192s 1s/step - loss: 0.5939 - accuracy: 0.7797 - val_loss: 0.7381 - val_accuracy: 0.7454 - time: 9520.1441
Epoch 9/10
897/897 [==============================] - 1187s 1s/step - loss: 666630.4375 - accuracy: 0.7840 - val_loss: 0.7031 - val_accuracy: 0.7437 - time: 10707.5617
Epoch 10/10
897/897 [==============================] - 1185s 1s/step - loss: 0.5122 - accuracy: 0.8174 - val_loss: 0.7276 - val_accuracy: 0.7478 - time: 11892.9372
Saving to ./experiments/35/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 3581s 1s/step - loss: nan - accuracy: 0.0903 - val_loss: nan - val_accuracy: 0.0685 - time: 3581.0230
Epoch 2/10
2690/2690 [==============================] - 3662s 1s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 7243.3928
Epoch 3/10
2690/2690 [==============================] - 6001s 2s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 13244.8738
Saving to ./experiments/36/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 5273s 2s/step - loss: nan - accuracy: 0.0667 - val_loss: nan - val_accuracy: 0.0685 - time: 5272.8774
Epoch 2/10
2690/2690 [==============================] - 4828s 2s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 10100.6725
Epoch 3/10
2690/2690 [==============================] - 4549s 2s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 14649.6744
Saving to ./experiments/37/sum.csv
Train (172125,)
Valid (19125,)
Test (33750,)
Epoch 1/10
2690/2690 [==============================] - 4745s 2s/step - loss: 143974.9844 - accuracy: 0.1826 - val_loss: 2.1231 - val_accuracy: 0.2460 - time: 4745.0529
Epoch 2/10
1493/2690 [===============>..............] - ETA: 17:04 - loss: 2.0408 - accuracy: 0.2720

Transfer learning¶

In [1305]:
class TransformerName(Enum):
    DistilBertBaseUncased = "distilbert-base-uncased"
    BertBaseUncased = "bert-base-uncased"
    ElectraSmall = "google/electra-small-discriminator"
    Blank = BLANK_DESCRIPTION
In [1306]:
from transformers import TFAutoModel
from transformers import AutoTokenizer
In [1307]:
def tokenize(sentences, tokenizer, max_length, padding='max_length'):
    return tokenizer(
        sentences,
        truncation=True,
        padding=padding,
        max_length=max_length,
        return_tensors="tf"
    )
In [1308]:
def run_transformer_model(
    transformer_name,
    output_sequence_length,
    number_of_authors,
    key,
    loss,
    optimizer,
    metrics,
    batch_size,
    epochs,
    lr
):
    MODEL_NAME = "Transformer"
    current_path = setup_directory()
    
    tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    
    current_data = data[str(number_of_authors)][key]
    loader = get_load_path_53 if number_of_authors == 5 else get_load_path_153
    encoder = create_encoder_from_path(loader(AUTHORS_FILENAME))
    
    X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(current_data)
    
    y_test = encoder.transform(y_test)
    y_train = encoder.transform(y_train)
    y_valid = encoder.transform(y_valid)
    
    
    
    train_ds = tf.data.Dataset.from_tensor_slices((
    dict(tokenize(list(X_train), tokenizer, output_sequence_length)),
    y_train
    )).batch(batch_size).prefetch(1)


    valid_ds = tf.data.Dataset.from_tensor_slices((
        dict(tokenize(list(X_valid), tokenizer, output_sequence_length)),
        y_valid
    )).batch(batch_size).prefetch(1)

    test_ds = tf.data.Dataset.from_tensor_slices((
        dict(tokenize(list(X_test), tokenizer, output_sequence_length)),
        y_test
    )).batch(1).prefetch(1)
    
    
    base = TFAutoModel.from_pretrained(
        transformer_name,
    )
    
    input_ids = tf.keras.layers.Input(shape=(output_sequence_length,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.layers.Input((output_sequence_length,), dtype=tf.int32, name='attention_mask')
    
    #Selection of cls
    output = base([input_ids, attention_mask]).last_hidden_state[:, 0, :]

    
    output = tf.keras.layers.Dropout(
        rate=0.15,
    )(output)

    output = tf.keras.layers.Dense(
        units=64,
        activation='relu',
    )(output)

    output = tf.keras.layers.BatchNormalization()(output)

    output = tf.keras.layers.Dense(
        units=64,
        activation='relu',
    )(output)

    output = tf.keras.layers.BatchNormalization()(output)

    output_layer = tf.keras.layers.Dense(
        units=number_of_authors,
        activation='softmax'
    )(output)


    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output_layer)

    model.summary()


    optimizer = optimizer(learning_rate=lr)
    
    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metrics,
    )
    
    history = model.fit(
        train_ds,
        validation_data=valid_ds,
        epochs=epochs,
        callbacks=[
            CSVLogger(current_path),
            es
        ]
    )
    
    prediction = model.predict(test_ds)

    
    y_pred = prediction_to_labels(prediction)
    accuracy = accuracy_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred)
    
    return save_experiment_info(
        current_path,
        ModelName=MODEL_NAME,
        BatchSize=batch_size,
        Optimizer=type(optimizer).__name__,
        Epochs=epochs,
        EmbeddingSize=BLANK,
        Time=BLANK,
        Accuracy=accuracy,
        LR=lr,
        Hits=0,
        Miss=0,
        Key=key,
        SeqLen=output_sequence_length,
        VocabSize=BLANK,
        TrainableEmbedding=True,
        ConfMatrix=conf_matrix,
        ModelType="TL",
        TransformerName=transformer_name,
        NumberOfAuthors=number_of_authors
    )
In [1309]:
def generate_model_transformer_experiments():
    for transformer_name in [TransformerName.DistilBertBaseUncased.value, TransformerName.BertBaseUncased.value, TransformerName.ElectraSmall.value ]:
            for author in [5]:
                for seq_len in [300]:
                    for key in ["LOWER_I"]:
                        for optimizer in [ADAM]:
                            for batch_size in [128]:
                                for epoch in [3]:
                                    for lr in TRANSFORMER_LR:
                                        yield (
                                            transformer_name,
                                            seq_len,
                                            author,
                                            key,
                                            LOSS,
                                            optimizer,
                                            METRICS,
                                            batch_size,
                                            epoch,
                                            lr
                                        )
In [1397]:
list(generate_model_transformer_experiments())
Out[1397]:
[('distilbert-base-uncased',
  300,
  5,
  'LOWER_I',
  <keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
  keras.optimizer_v2.adam.Adam,
  [<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
  128,
  3,
  0.001),
 ('distilbert-base-uncased',
  300,
  5,
  'LOWER_I',
  <keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
  keras.optimizer_v2.adam.Adam,
  [<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
  128,
  3,
  5e-05),
 ('bert-base-uncased',
  300,
  5,
  'LOWER_I',
  <keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
  keras.optimizer_v2.adam.Adam,
  [<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
  128,
  3,
  0.001),
 ('bert-base-uncased',
  300,
  5,
  'LOWER_I',
  <keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
  keras.optimizer_v2.adam.Adam,
  [<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
  128,
  3,
  5e-05),
 ('google/electra-small-discriminator',
  300,
  5,
  'LOWER_I',
  <keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
  keras.optimizer_v2.adam.Adam,
  [<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
  128,
  3,
  0.001),
 ('google/electra-small-discriminator',
  300,
  5,
  'LOWER_I',
  <keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
  keras.optimizer_v2.adam.Adam,
  [<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
  128,
  3,
  5e-05)]
In [1310]:
len(list(generate_model_transformer_experiments()))
Out[1310]:
6
In [ ]:
for exp_values in generate_model_transformer_experiments():
    run_transformer_model(
        *exp_values
    )
Train (57375,)
Valid (6375,)
Test (11250,)
Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_projector', 'vocab_transform', 'activation_13', 'vocab_layer_norm']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.
Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_ids (InputLayer)         [(None, 300)]        0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 300)]        0           []                               
                                                                                                  
 tf_distil_bert_model_5 (TFDist  TFBaseModelOutput(l  66362880   ['input_ids[0][0]',              
 ilBertModel)                   ast_hidden_state=(N               'attention_mask[0][0]']         
                                one, 300, 768),                                                   
                                 hidden_states=None                                               
                                , attentions=None)                                                
                                                                                                  
 tf.__operators__.getitem_5 (Sl  (None, 768)         0           ['tf_distil_bert_model_5[0][0]'] 
 icingOpLambda)                                                                                   
                                                                                                  
 dropout_309 (Dropout)          (None, 768)          0           ['tf.__operators__.getitem_5[0][0
                                                                 ]']                              
                                                                                                  
 dense_291 (Dense)              (None, 64)           49216       ['dropout_309[0][0]']            
                                                                                                  
 batch_normalization_11 (BatchN  (None, 64)          256         ['dense_291[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 dense_292 (Dense)              (None, 64)           4160        ['batch_normalization_11[0][0]'] 
                                                                                                  
 batch_normalization_12 (BatchN  (None, 64)          256         ['dense_292[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 dense_293 (Dense)              (None, 5)            325         ['batch_normalization_12[0][0]'] 
                                                                                                  
==================================================================================================
Total params: 66,417,093
Trainable params: 66,416,837
Non-trainable params: 256
__________________________________________________________________________________________________
Epoch 1/3
449/449 [==============================] - 12732s 28s/step - loss: 1.7213 - accuracy: 0.1676 - val_loss: 1.6598 - val_accuracy: 0.1987 - time: 12731.7159
Epoch 2/3
314/449 [===================>..........] - ETA: 1:01:37 - loss: 1.6307 - accuracy: 0.2010

Vyhodnocení¶

In [2]:
from os.path import exists as file_exists
In [13]:
filenames = [
    SUMMAR,
    LOG,
]
In [14]:
class Storage:
    def __init__(self):
        self.records = []

    def reset(self):
        self.records = []

    def run(self, directory=None):
        self.directory = directory

        if self.directory is None:
            return

        create_dataframe(self.directory, self.records)

    def get_dataframe(self):
        mapped = map(lambda x: x.iloc[1, :].values, self.records)
        df = pd.DataFrame(mapped)
        new_header = self.records[0].iloc[0, :].values
        df.columns = new_header
        return df
    
def create_dataframe(start_directory, storage=None):
    return process_directory(start_directory, storage)

def process_directory(directory, storage=None):
    is_correct = is_correct_file(directory)
    record = None

    if is_correct:
        if storage is not None:
            record = create_record(directory)
            if record is not None:
                storage.append(record)

    for current_directory in os.listdir(directory):
        deeper_level = os.path.sep.join([directory, current_directory])
        if os.path.isdir(deeper_level):
            process_directory(deeper_level, storage)

        
def is_correct_file(path):
    for filename in filenames:
        current_path = os.path.sep.join([path, filename])
        if os.path.exists(current_path):
            return True
    return False


def exists(directory, filename):
    current_path = os.path.sep.join([directory, filename])
    if file_exists(current_path):
        return current_path
    return None

def create_record(directory):
    try:
        log = parse_log(directory)
        summ = parse_summa(directory)
        record = merge_content(
            summ,
            log
        )
        return record
    except Exception as e:
        print(f"Exception in {directory}")
        print(f"Exception {e}")
        return None
    

def parse_summa(directory):
    path = exists(directory, SUMMAR)

    if path is None:
        return None

    content = pd.read_csv(path, sep=";", header=None)
    return content

def parse_log(directory):
    path = exists(directory, LOG)

    if path is None:
        return None
    
    content = pd.read_csv(path, sep=";")
    
    
    dic = {}
    for index in range(content.shape[1]):
        key = content.columns[index]
        if "Unnamed" not in key:
            dic[key] = [content.iloc[:, index].values]
    
            
    res = pd.DataFrame.from_dict(dic, orient="index").reset_index()
    res.columns = [0, 1]
    return res

def merge_content(log=pd.DataFrame(), summ=pd.DataFrame()):
    concat_df = pd.concat([summ, log])

    record = concat_df.T

    return record

Načtení dat¶

In [257]:
start_directory = os.path.sep.join(EXPERIMENTS_SAVE_DIRECTORY)
In [258]:
storage = Storage()
In [259]:
storage.run(start_directory)
In [260]:
pd.set_option('display.max_columns', None)
df = storage.get_dataframe()
In [261]:
df.index = list(range(len(df)))
In [262]:
len(df)
Out[262]:
53
In [263]:
import math
In [264]:
df['Accuracy'] = list(map(lambda x: round(float(x), 2), df['Accuracy'].values))
In [265]:
df
Out[265]:
loss accuracy val_loss val_accuracy time NaN ModelName BatchSize Optimizer LR Epochs EmbeddingSize Time Accuracy Hits Miss Key SeqLen VocabSize TrainableEmbedding ConfMatrix Type TransformerName NumberOfAuthors
0 [1.1091737747192385, 0.545360803604126, 0.3251... [0.545690655708313, 0.8149368166923523, 0.8950... [0.6624863743782043, 0.658841609954834, 0.7480... [0.7519999742507935, 0.774117648601532, 0.7775... [19.60962748527527, 38.74808216094971, 57.6085... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 RAW 200 10000 True [[1678 158 256 36 85]\n [ 172 1459 199 ... NORMAL - 5
1 [1.1327065229415894, 0.5607202053070068, 0.323... [0.5527529120445251, 0.807494580745697, 0.8943... [0.67454594373703, 0.6359243392944336, 0.74231... [0.7509019374847412, 0.779764711856842, 0.7863... [20.090360403060917, 39.35527753829956, 58.566... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 LOWER 200 10000 True [[1488 306 238 45 136]\n [ 85 1690 127 ... NORMAL - 5
2 [1.225493311882019, 0.6993909478187561, 0.4104... [0.4702745079994201, 0.7103093862533569, 0.854... [0.8190874457359314, 0.6555845737457275, 0.699... [0.6439215540885925, 0.7667450904846191, 0.792... [18.983893871307373, 37.27286505699158, 55.811... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 DEFAULT 200 10000 True [[1451 183 376 77 126]\n [ 280 1587 93 ... NORMAL - 5
3 [1.1457595825195312, 0.5117649435997009, 0.299... [0.5310588479042053, 0.8236165642738342, 0.901... [0.6402604579925537, 0.5437538027763367, 0.646... [0.7640784382820129, 0.8133333325386047, 0.811... [19.854929447174072, 38.84715795516968, 58.101... 0 DENSE 64 Adam 0.001 - 50 - 0.81 0 0 LOWER_I 200 10000 True [[1670 150 205 51 137]\n [ 198 1715 102 ... NORMAL - 5
4 [1.2395521402359009, 0.6198447942733765, 0.376... [0.4878901839256286, 0.7835119962692261, 0.877... [0.7236012816429138, 0.5898836255073547, 0.644... [0.7349019646644592, 0.7899608016014099, 0.790... [24.031248092651367, 47.32864117622376, 70.390... 0 DENSE 64 Adam 0.001 - 50 - 0.79 0 0 RAW 400 10000 True [[1691 143 276 33 70]\n [ 141 1608 183 ... NORMAL - 5
5 [1.256301760673523, 0.6661115288734436, 0.3901... [0.4733960926532745, 0.7623529434204102, 0.872... [0.8248462677001953, 0.6136951446533203, 0.676... [0.6912941336631775, 0.7836862802505493, 0.781... [23.661153078079224, 46.77628660202026, 69.989... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 LOWER 400 10000 True [[1750 172 193 31 67]\n [ 198 1542 185 ... NORMAL - 5
6 [1.4970093965530396, 0.8270591497421265, 0.494... [0.3286588191986084, 0.6616818904876709, 0.820... [1.0638797283172607, 0.6806007027626038, 0.622... [0.562666654586792, 0.7440000176429749, 0.7909... [22.825047254562374, 44.74828147888184, 66.949... 0 DENSE 64 Adam 0.001 - 50 - 0.79 0 0 DEFAULT 400 10000 True [[1574 149 312 64 114]\n [ 230 1586 166 ... NORMAL - 5
7 [1.224252700805664, 0.5908034443855286, 0.3403... [0.4824941158294678, 0.7896470427513123, 0.886... [0.7540565133094788, 0.5498793125152588, 0.583... [0.7225098013877869, 0.8059607744216919, 0.817... [23.619723081588745, 46.35664224624634, 68.851... 0 DENSE 64 Adam 0.001 - 50 - 0.80 0 0 LOWER_I 400 10000 True [[1559 158 332 46 118]\n [ 127 1619 214 ... NORMAL - 5
8 [2.038879871368408, 1.563637614250183, 1.32510... [0.3195910453796386, 0.4616528749465942, 0.546... [1.6314889192581177, 1.4994858503341677, 1.474... [0.4371764659881592, 0.4931764602661133, 0.517... [58.19883918762207, 116.40550565719604, 174.34... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 RAW 200 10000 True [[1382 246 16 48 47 53 275 73 19... NORMAL - 15
9 [2.007798671722412, 1.554110050201416, 1.32113... [0.3344993591308594, 0.4640406668186188, 0.547... [1.6179845333099363, 1.4908241033554075, 1.485... [0.4401045739650726, 0.4930196106433868, 0.519... [57.74453091621399, 115.19541382789612, 171.85... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 LOWER 200 10000 True [[1182 218 46 45 52 60 311 104 10... NORMAL - 15
10 [2.0945076942443848, 1.5407600402832031, 1.260... [0.3058666586875915, 0.4837879538536072, 0.582... [1.6695539951324463, 1.3899223804473877, 1.346... [0.4403137266635895, 0.5398169755935669, 0.576... [55.19337034225464, 109.7722339630127, 164.612... 0 DENSE 64 Adam 0.001 - 50 - 0.58 0 0 DEFAULT 200 10000 True [[1521 213 24 13 44 56 146 17 9... NORMAL - 15
11 [2.0426511764526367, 1.4889464378356934, 1.221... [0.331916332244873, 0.4952156841754913, 0.5902... [1.6068023443222046, 1.3842854499816897, 1.342... [0.4554771184921264, 0.5311895608901978, 0.570... [57.66135025024414, 114.47401738166808, 171.94... 0 DENSE 64 Adam 0.001 - 50 - 0.58 0 0 LOWER_I 200 10000 True [[1535 153 14 8 31 47 290 53 11... NORMAL - 15
12 [2.0767838954925537, 1.5745370388031006, 1.335... [0.3215215802192688, 0.4560929536819458, 0.542... [1.649091720581055, 1.498886227607727, 1.46961... [0.4290196001529693, 0.4890457391738891, 0.523... [68.15565466880798, 133.44615292549133, 200.59... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 RAW 400 10000 True [[1258 179 48 47 35 38 323 37 10... NORMAL - 15
13 [2.10368275642395, 1.6013576984405518, 1.34118... [0.3073516488075256, 0.4520435631275177, 0.544... [1.697721242904663, 1.4982041120529177, 1.4765... [0.4145882427692413, 0.4970980286598205, 0.525... [69.45309066772461, 136.9076645374298, 204.584... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 LOWER 400 10000 True [[1307 134 32 37 30 72 318 55 22... NORMAL - 15
14 [2.148930549621582, 1.7112888097763062, 1.4490... [0.2799529433250427, 0.3945388495922088, 0.491... [1.811055302619934, 1.5882717370986938, 1.4911... [0.3535163402557373, 0.4577777683734894, 0.499... [64.82900762557983, 130.9426231384277, 195.940... 0 DENSE 64 Adam 0.001 - 50 - 0.55 0 0 DEFAULT 400 10000 True [[1439 88 23 22 27 74 353 36 18... NORMAL - 15
15 [2.1052839756011963, 1.5557817220687866, 1.270... [0.3025882244110107, 0.4585562944412231, 0.561... [1.6754242181777954, 1.404281497001648, 1.3419... [0.414274513721466, 0.5210980176925659, 0.5575... [68.50114560127258, 136.1558701992035, 204.361... 0 DENSE 64 Adam 0.001 - 50 - 0.56 0 0 LOWER_I 400 10000 True [[1434 184 20 21 17 45 317 107 14... NORMAL - 15
16 [1.2060670852661133, 0.5876715183258057, 0.301... [0.5213333368301392, 0.7960087060928345, 0.901... [0.7380061745643616, 0.6602489352226257, 0.777... [0.722196102142334, 0.7658039331436157, 0.7722... [42.88776779174805, 84.30938196182251, 125.639... 0 DENSE 64 Adam 0.001 - 300 - 0.76 0 0 RAW 200 10000 True [[1499 249 353 35 77]\n [ 109 1561 244 ... NORMAL - 5
17 [1.1218571662902832, 0.5309804677963257, 0.279... [0.5644392371177673, 0.8198866844177246, 0.911... [0.6812100410461426, 0.6584259867668152, 0.761... [0.7573333382606506, 0.7799215912818909, 0.788... [42.935046672821045, 84.6618320941925, 126.400... 0 DENSE 64 Adam 0.001 - 300 - 0.76 0 0 LOWER 200 10000 True [[1526 250 228 57 152]\n [ 111 1533 188 ... NORMAL - 5
18 [1.2983185052871704, 0.67186439037323, 0.35850... [0.4403764605522156, 0.7503268122673035, 0.875... [0.8450406789779663, 0.6062559485435486, 0.653... [0.6599215865135193, 0.7830588221549988, 0.803... [41.16602444648743, 81.34677410125732, 120.331... 0 DENSE 64 Adam 0.001 - 300 - 0.79 0 0 DEFAULT 200 10000 True [[1461 291 227 76 158]\n [ 174 1726 94 ... NORMAL - 5
19 [1.1131271123886108, 0.4970025420188904, 0.263... [0.5663529634475708, 0.8301350474357605, 0.912... [0.634172260761261, 0.5626846551895142, 0.6592... [0.7694117426872253, 0.8014117479324341, 0.809... [42.32026171684265, 83.03080439567566, 123.852... 0 DENSE 64 Adam 0.001 - 300 - 0.80 0 0 LOWER_I 200 10000 True [[1508 290 287 32 96]\n [ 77 1770 190 ... NORMAL - 5
20 [1.4313589334487915, 0.976482629776001, 0.5856... [0.3718431293964386, 0.5906579494476318, 0.785... [1.1860140562057495, 0.7914318442344666, 0.690... [0.492078423500061, 0.7022745013237, 0.7548235... [75.83154845237732, 148.4723603725433, 221.875... 0 DENSE 64 Adam 0.001 - 300 - 0.76 0 0 RAW 400 10000 True [[1521 245 329 36 82]\n [ 127 1526 201 ... NORMAL - 5
21 [1.2526357173919678, 0.7180959582328796, 0.415... [0.4922509789466858, 0.7275816798210144, 0.859... [0.8532054424285889, 0.7139995694160461, 0.696... [0.6709019541740417, 0.7447842955589294, 0.779... [72.2411801815033, 143.07429265975952, 214.991... 0 DENSE 64 Adam 0.001 - 300 - 0.77 0 0 LOWER 400 10000 True [[1541 148 435 16 73]\n [ 74 1643 317 ... NORMAL - 5
22 [1.612979292869568, 1.6098767518997192, 1.6102... [0.2574901878833771, 0.2004705816507339, 0.200... [1.6093101501464844, 1.6092772483825684, 1.609... [0.1923137307167053, 0.199215680360794, 0.1920... [71.83844518661499, 143.0635724067688, 215.440... 0 DENSE 64 Adam 0.001 - 300 - 0.21 0 0 DEFAULT 400 10000 True [[ 3 0 2210 0 0]\n [ 3 0 2200 ... NORMAL - 5
23 [1.6123055219650269, 1.4136512279510498, 0.892... [0.1997019648551941, 0.339607834815979, 0.6405... [1.6083794832229614, 1.10409414768219, 0.74592... [0.1921568661928177, 0.5328627228736877, 0.719... [72.79997491836548, 144.0263090133667, 215.028... 0 DENSE 64 Adam 0.001 - 300 - 0.79 0 0 LOWER_I 400 10000 True [[1617 342 145 20 89]\n [ 121 1746 103 ... NORMAL - 5
24 [2.010152578353882, 1.5290956497192385, 1.2335... [0.3290140032768249, 0.4641103744506836, 0.562... [1.6453301906585691, 1.5359132289886477, 1.550... [0.4279738664627075, 0.4795294106006622, 0.506... [122.05491280555724, 241.92626881599423, 364.0... 0 DENSE 64 Adam 0.001 - 300 - 0.49 0 0 RAW 200 10000 True [[1071 369 36 61 34 64 271 114 7... NORMAL - 15
25 [2.0807976722717285, 1.5866154432296753, 1.302... [0.3101803958415985, 0.4424575269222259, 0.535... [1.694566249847412, 1.5601444244384766, 1.5679... [0.4060130715370178, 0.4670849740505218, 0.481... [122.93492102622986, 244.07795214653012, 366.3... 0 DENSE 64 Adam 0.001 - 300 - 0.47 0 0 LOWER 200 10000 True [[1157 287 34 93 38 72 240 96 18... NORMAL - 15
26 [2.22964859008789, 1.658429503440857, 1.336077... [0.257966011762619, 0.4204676747322082, 0.5402... [1.8096245527267456, 1.5306107997894287, 1.440... [0.3592156767845154, 0.4859085083007812, 0.534... [124.48556709289552, 248.0258502960205, 368.81... 0 DENSE 64 Adam 0.001 - 300 - 0.54 0 0 DEFAULT 200 10000 True [[1180 105 23 97 50 109 413 26 7... NORMAL - 15
27 [1.9838707447052, 1.4141194820404053, 1.121265... [0.3484601378440857, 0.5100711584091187, 0.607... [1.531771898269653, 1.3959068059921265, 1.3805... [0.4722614288330078, 0.5345359444618225, 0.558... [122.99041819572447, 245.43910098075867, 367.5... 0 DENSE 64 Adam 0.001 - 300 - 0.56 0 0 LOWER_I 200 10000 True [[1257 309 28 8 77 50 398 8 3... NORMAL - 15
28 [2.298656940460205, 1.906553149223328, 1.59620... [0.2510797381401062, 0.3431517779827118, 0.438... [2.042023181915283, 1.7355191707611084, 1.6539... [0.3018562197685241, 0.4036078453063965, 0.439... [219.676766872406, 433.9881939888001, 650.6269... 0 DENSE 64 Adam 0.001 - 300 - 0.44 0 0 RAW 400 10000 True [[1050 246 42 204 46 94 260 77 4... NORMAL - 15
29 [2.215775489807129, 1.8288313150405884, 1.5736... [0.2619398832321167, 0.3532491028308868, 0.434... [1.9354660511016848, 1.716828465461731, 1.6661... [0.327320247888565, 0.3944052159786224, 0.4274... [212.37938380241397, 417.8751857280731, 627.87... 0 DENSE 64 Adam 0.001 - 300 - 0.47 0 0 LOWER 400 10000 True [[1142 226 8 64 130 116 306 88 8... NORMAL - 15
30 [2.7099225521087646, 2.708319902420044, 2.7084... [0.1086745113134384, 0.0673812627792358, 0.066... [2.7085378170013428, 2.708480834960937, 2.7085... [0.0658823549747467, 0.0657777786254882, 0.064... [213.5495536327362, 428.6686849594116, 646.272... 0 DENSE 64 Adam 0.001 - 300 - 0.43 0 0 DEFAULT 400 10000 True [[ 103 125 11 296 155 802 447 58 7... NORMAL - 15
31 [6687085056.0, 1187478372352.0, 1.320836067199... [0.3819360733032226, 0.4748932421207428, 0.597... [nan, 1.0665013790130615, 0.9180729389190674, ... [0.294117659330368, 0.5433725714683533, 0.6208... [1092.711680173874, 2239.074378967285, 3320.77... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.69 0 0 RAW 300 10000 True [[1428 125 512 34 114]\n [ 50 1417 393 ... NORMAL - 5
32 [2406436.75, 25465.07421875, nan, nan, nan] [0.3964235186576843, 0.5264662504196167, 0.383... [1.1757404804229736, 1.1074206829071045, nan, ... [0.485647052526474, 0.5080784559249878, 0.1921... [1368.7052409648895, 2801.06485581398, 4307.10... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.51 0 0 LOWER 300 10000 True [[1731 434 4 3 41]\n [ 601 1472 94 ... NORMAL - 5
33 [1.3197650909423828, 183.88555908203125, 0.674... [0.3472784459590912, 0.5911459922790527, 0.748... [1.1397396326065063, 0.7558140158653259, 0.713... [0.4633725583553314, 0.7105882167816162, 0.735... [1794.731845855713, 3350.3680169582367, 4541.5... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.74 0 0 DEFAULT 300 10000 True [[1087 167 730 68 161]\n [ 17 1742 242 ... NORMAL - 5
34 [278055.28125, 1.4219011068344116, 31504.09765... [0.3278745114803314, 0.4618736505508423, 0.516... [1.2132927179336548, 1.139415979385376, 1.1058... [0.4410980343818664, 0.4923921525478363, 0.526... [1202.219447374344, 2394.1133301258087, 3581.0... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.75 0 0 LOWER_I 300 10000 True [[1397 428 212 64 112]\n [ 143 1533 396 ... NORMAL - 5
35 [nan, nan, nan] [0.0903417393565177, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [3581.0230338573456, 7243.392804384232, 13244.... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.07 0 0 RAW 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15
36 [nan, nan, nan] [0.0666928067803382, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [5272.877393007278, 10100.672457456589, 14649.... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.07 0 0 LOWER 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15
37 [143974.984375, 1.981084942817688, nan, nan, nan] [0.1826091557741165, 0.2948496639728546, 0.301... [2.123115539550781, 1.8301618099212649, nan, n... [0.2459607869386673, 0.3584313690662384, 0.068... [4745.052874326706, 9517.231180667875, 14253.2... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.36 0 0 DEFAULT 300 10000 True [[ 552 57 29 485 1 132 234 68 4... NORMAL - 15
38 [nan, nan, nan] [0.0665464028716087, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [4647.039727687836, 9398.7957239151, 14184.882... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.07 0 0 LOWER_I 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15
39 [nan, nan, nan] [0.1690849661827087, 0.2015163451433181, 0.201... [nan, nan, nan] [0.1921568661928177, 0.1921568661928177, 0.192... [1732.056715965271, 3492.265405893326, 5230.93... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.20 0 0 RAW 300 10000 True [[2213 0 0 0 0]\n [2203 0 0 ... NORMAL - 5
40 [nan, nan, nan] [0.3445647060871124, 0.2015163451433181, 0.201... [nan, nan, nan] [0.1921568661928177, 0.1921568661928177, 0.192... [1745.630146026611, 3480.498616695404, 4916.81... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.20 0 0 LOWER 300 10000 True [[2213 0 0 0 0]\n [2203 0 0 ... NORMAL - 5
41 [9.50131130218506, 0.7876715660095215, 1.23750... [0.3992627561092376, 0.6950762271881104, 0.808... [0.9494749307632446, 0.6531278491020203, 0.620... [0.6191372275352478, 0.7658039331436157, 0.788... [1764.6829175949097, 3511.476901292801, 5241.2... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.80 0 0 DEFAULT 300 10000 True [[1597 259 167 86 104]\n [ 151 1747 64 ... NORMAL - 5
42 [117312464.0, nan, nan, nan] [0.4742588102817535, 0.4018823504447937, 0.201... [1.0666706562042236, nan, nan, nan] [0.5543529391288757, 0.1921568661928177, 0.192... [1787.1415581703186, 3542.0427582263947, 5307.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.57 0 0 LOWER_I 300 10000 True [[1077 131 929 3 73]\n [ 132 1246 712 ... NORMAL - 5
43 [nan, nan, nan] [0.1155686303973198, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [5241.320497512817, 10400.483037471771, 15546.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.07 0 0 RAW 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15
44 [4.139690399169922, 31051933696.0, 5946009.0, ... [0.2126849740743637, 0.3071604967117309, 0.362... [45.60334777832031, 1.8750523328781128, 1.7882... [0.2989281117916107, 0.3549803793430328, 0.382... [4035.43758225441, 7763.776393175125, 11458.88... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.56 0 0 LOWER 300 10000 True [[1488 176 59 4 126 19 158 42 5... NORMAL - 15
45 [3889620.75, 5.671187400817871, 174.6585388183... [0.2653124034404754, 0.3765461146831512, 0.458... [74537.96875, 1.6551882028579712, 3.5518321990... [0.3310849666595459, 0.4476862847805023, 0.468... [3747.2181475162506, 7466.645300388336, 11197.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.57 0 0 DEFAULT 300 10000 True [[1448 185 41 29 71 61 236 37 5... NORMAL - 15
46 [47510450176.0, 10461987.0, 3.044379234313965,... [0.2487633973360061, 0.3066957294940948, 0.376... [2.28581562481608e+19, 1.265042553581863e+18, ... [0.3026405274868011, 0.3697254955768585, 0.433... [3202.14945268631, 6400.369203567505, 9596.970... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.43 0 0 LOWER_I 300 10000 True [[ 918 165 76 21 2 162 512 56 9... NORMAL - 15
47 [1.721323847770691, 1.6283942461013794, 1.6182... [0.167633980512619, 0.200139433145523, 0.20036... [1.6598328351974487, 1.6358952522277832, 1.622... [0.1987451016902923, 0.1921568661928177, 0.192... [12731.715883731842, 25516.930138111115, 38309... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[2213 0 0 0 0]\n [2203 0 0 ... TL distilbert-base-uncased 5
48 [0.6788761615753174, 0.2942142188549042, 0.161... [0.6950744986534119, 0.8956339955329895, 0.945... [0.4689745604991913, 0.3302323520183563, 0.374... [0.8304314017295837, 0.8854901790618896, 0.882... [12792.493296146393, 25582.148589611053, 38391... 0 Transformer 128 Adam 5e-05 - - - 0.88 0 0 LOWER_I 300 - True [[1932 65 99 57 60]\n [ 96 1780 103 ... TL distilbert-base-uncased 5
49 [1.720706582069397, 1.62993323802948, 1.619087... [0.2695686221122741, 0.1993899792432785, 0.199... [1.612503170967102, 1.6107159852981567, 1.6184... [0.2062745094299316, 0.2043921500444412, 0.198... [25539.273556947708, 51105.00569915772, 76684.... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[ 0 0 0 0 2213]\n [ 0 0 0 ... TL bert-base-uncased 5
50 [0.6129876375198364, 0.2475549280643463, 0.137... [0.7192941308021545, 0.9139520525932312, 0.954... [0.3863182961940765, 0.3083285987377167, 0.356... [0.8680784106254578, 0.8953725695610046, 0.892... [25577.77806091309, 51058.13135123253, 76546.3... 0 Transformer 128 Adam 5e-05 - - - 0.89 0 0 LOWER_I 300 - True [[1845 72 102 118 76]\n [ 34 1847 67 ... TL bert-base-uncased 5
51 [1.7034351825714111, 1.6270469427108765, 1.617... [0.2704313695430755, 0.198047935962677, 0.1987... [1.6337881088256836, 1.6143453121185305, 1.611... [0.2062745094299316, 0.2043921500444412, 0.198... [8417.139698982239, 16806.370790719986, 25206.... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[ 0 0 0 0 2213]\n [ 0 0 0 ... TL google/electra-small-discriminator 5
52 [1.167790174484253, 0.668626070022583, 0.48968... [0.5210353136062622, 0.760610044002533, 0.8271... [0.8901610970497131, 0.8246893286705017, 0.576... [0.6842352747917175, 0.7243921756744385, 0.799... [8379.385590314865, 16772.64554834366, 25164.1... 0 Transformer 128 Adam 5e-05 - - - 0.81 0 0 LOWER_I 300 - True [[1545 117 274 131 146]\n [ 58 1393 243 ... TL google/electra-small-discriminator 5
In [266]:
df['CalculationTime'] = list(map(lambda x: round(np.sum(x), 2), df.time.values))

Shrnutí¶

Popis datové sady¶

Jak již bylo zmíněno datová sada byla vytvořena vlastními silami. Byl využit Projekt Gutenberg, který obsahuje umělecká díla. Z toho projektu pomocí R skriptu byly separovány díla s anglickým textem a zároveň taková, která obsahovala jasně specifikovaného autora. Tento fakt nám umožnil vytvořit finální datovou sadu. Tato datová sada textový řetězec reprezentovaný velikosti o n větách a identifikátor, který reprezentoval autora díla.

S touto datovou sadou bylo pracováno v diplomové práci, přičemž tento jupyter vznikl akorát pro tento projekt. Jupyter byl extra vytvořen s použitím části kódu z diplomové práce.

Chyba¶

Datová sada obsahovala identifikátory pro autory, přičemž zde byly využity datové sady o 5 a 15 autorech. Tyto identifikátory byly transformovaný pomocí LabelEncoderu do žádaného prostoru. Respektive čísel 0-5, 0-15. K samotnému určení, kdo umělecké dílo napsal byla využita klasifikace. Chybou proto byla využita SparseCategoricalCrossentropy, abychom nemuseli target vektory transformovat to one hot vektorů.

Výběr dat¶

Vytvořená datová sada neobsahovala stejný počet záznamů pro každého autora, a proto bylo přistoupeno k normalizaci na určitou hodnotu vzhledem k nejslabšímu autorovi. Pro představu zde v projektu byla využita normalizační hodnota 15 000. Takže ke každému autorovi bylo v datové sadě ponecháno 15 000 záznamů, které byly náhodně vybrány ze všech. Díky tohoto přístupu jsme zároveň mohli využít metriku, která bude zmíněna níže.

Metrika¶

Jako metrika byla využita přesnost. Toto jsme si mohli dovolit, díky toho že datová sada obsahovala stejný počet záznamů, pro každého autora. Číselná hodnota nám pak říkala, s jakou přesnosti jsme schopni určit autora vzhledem k autorům s kterými jsme pracovali.

Předzpracování dat¶

V rámci předzpracování dat bylo experimentováno s 4 přístupu:

  • Surová data.

    • Nedocházelo k žádnému předzpracování.
  • Data upravena pomocí gensim metody na předzpracování.

    • Pro referenci uvádím odkaz na metodu https://radimrehurek.com/gensim/parsing/preprocessing.html.
    • ... listové vypsání hodnotím za komplexní operaci.
  • Data transformována do malých písmen.

    • Dochází pouze k lower() na textový řetězec MAD.lower() = mad.
  • Data transformovaná do malých písmen společně s malým předzpracováním.

    • Odstranění slov, které jsou v blacklistu.
    • Transformace do malých písmen.
    • Odstranení interpunkčních znamének.
    • Odstranění bílých znaků.
    • Odstranění numerických znaků.

Účelem bylo porovnat jaký vliv má předzpracování na datovou sadu.

Aplikací předzpracování vždy ztrácíme určitou informaci. Většinou bývají textová data na tuto operaci náchylná a je nutné provádět předzpracování s rozvahou.

Po předzpracování můžeme extrahovat relevantní informace, zmenšit vstupní vektor záznamu, získat základní tvar slova, a tak využít již předučený číselný vektor a mnoho dalšího.

Experimenty¶

Obecný popis¶

Bylo provedeno 53 experimentů, které budou zpracovány. Jak již bylo zmíněno experimenty byly provedeny po 4 různých přístupech předzpracování. Zároveň byly využity různé přístupy a to:

  • Jednoduchý vlastní model založen na husté propojené síti s embedding vrstvou.
  • Vlastní model založen na sofistikovanější rekurentní neuronové síti s embedding vrstvou.
  • Využití state-of-art modelů založených na Transformer architektuře.

V rámci těchto tříd bylo provedeno postupně:

  • 32 experimentů s různou konfigurací.
  • 16 experimentů s různou konfigurací.
  • 6 experimentů s různou konfigurací.

Jednoduchý model a rekurentní model pracoval s různou velikostí vstupu a zároveň experimentoval s velikostí embeddingu, který byl učen jakožto reprezentace jednotlivých tokenů. Zároveň pro tyto modely byly testovány obě množiny s 5 a 15 autory.

Složitější Transformer model byl bohužel zpracován pouze v podobě 5 autorů a jednom druhu předzpracování. Za to byly využity 3 druhy Transformeru, abychom mohli porovnat jaký model si vede nejlépe. K tomuto porovnání byly vybrány modely:

  • Electra
  • Bert
  • DistilBert

Více specifický popis bude vytvořen u grafů, které budou vizualizovat dosažené výsledky.

Grafy¶

Jednoduchá hluboká neuronová síť¶

In [267]:
dense_df = df[df.ModelName == "DENSE"] 
In [268]:
dense_df
Out[268]:
loss accuracy val_loss val_accuracy time NaN ModelName BatchSize Optimizer LR Epochs EmbeddingSize Time Accuracy Hits Miss Key SeqLen VocabSize TrainableEmbedding ConfMatrix Type TransformerName NumberOfAuthors CalculationTime
0 [1.1091737747192385, 0.545360803604126, 0.3251... [0.545690655708313, 0.8149368166923523, 0.8950... [0.6624863743782043, 0.658841609954834, 0.7480... [0.7519999742507935, 0.774117648601532, 0.7775... [19.60962748527527, 38.74808216094971, 57.6085... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 RAW 200 10000 True [[1678 158 256 36 85]\n [ 172 1459 199 ... NORMAL - 5 288.14
1 [1.1327065229415894, 0.5607202053070068, 0.323... [0.5527529120445251, 0.807494580745697, 0.8943... [0.67454594373703, 0.6359243392944336, 0.74231... [0.7509019374847412, 0.779764711856842, 0.7863... [20.090360403060917, 39.35527753829956, 58.566... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 LOWER 200 10000 True [[1488 306 238 45 136]\n [ 85 1690 127 ... NORMAL - 5 292.77
2 [1.225493311882019, 0.6993909478187561, 0.4104... [0.4702745079994201, 0.7103093862533569, 0.854... [0.8190874457359314, 0.6555845737457275, 0.699... [0.6439215540885925, 0.7667450904846191, 0.792... [18.983893871307373, 37.27286505699158, 55.811... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 DEFAULT 200 10000 True [[1451 183 376 77 126]\n [ 280 1587 93 ... NORMAL - 5 278.50
3 [1.1457595825195312, 0.5117649435997009, 0.299... [0.5310588479042053, 0.8236165642738342, 0.901... [0.6402604579925537, 0.5437538027763367, 0.646... [0.7640784382820129, 0.8133333325386047, 0.811... [19.854929447174072, 38.84715795516968, 58.101... 0 DENSE 64 Adam 0.001 - 50 - 0.81 0 0 LOWER_I 200 10000 True [[1670 150 205 51 137]\n [ 198 1715 102 ... NORMAL - 5 290.37
4 [1.2395521402359009, 0.6198447942733765, 0.376... [0.4878901839256286, 0.7835119962692261, 0.877... [0.7236012816429138, 0.5898836255073547, 0.644... [0.7349019646644592, 0.7899608016014099, 0.790... [24.031248092651367, 47.32864117622376, 70.390... 0 DENSE 64 Adam 0.001 - 50 - 0.79 0 0 RAW 400 10000 True [[1691 143 276 33 70]\n [ 141 1608 183 ... NORMAL - 5 352.12
5 [1.256301760673523, 0.6661115288734436, 0.3901... [0.4733960926532745, 0.7623529434204102, 0.872... [0.8248462677001953, 0.6136951446533203, 0.676... [0.6912941336631775, 0.7836862802505493, 0.781... [23.661153078079224, 46.77628660202026, 69.989... 0 DENSE 64 Adam 0.001 - 50 - 0.77 0 0 LOWER 400 10000 True [[1750 172 193 31 67]\n [ 198 1542 185 ... NORMAL - 5 347.99
6 [1.4970093965530396, 0.8270591497421265, 0.494... [0.3286588191986084, 0.6616818904876709, 0.820... [1.0638797283172607, 0.6806007027626038, 0.622... [0.562666654586792, 0.7440000176429749, 0.7909... [22.825047254562374, 44.74828147888184, 66.949... 0 DENSE 64 Adam 0.001 - 50 - 0.79 0 0 DEFAULT 400 10000 True [[1574 149 312 64 114]\n [ 230 1586 166 ... NORMAL - 5 467.94
7 [1.224252700805664, 0.5908034443855286, 0.3403... [0.4824941158294678, 0.7896470427513123, 0.886... [0.7540565133094788, 0.5498793125152588, 0.583... [0.7225098013877869, 0.8059607744216919, 0.817... [23.619723081588745, 46.35664224624634, 68.851... 0 DENSE 64 Adam 0.001 - 50 - 0.80 0 0 LOWER_I 400 10000 True [[1559 158 332 46 118]\n [ 127 1619 214 ... NORMAL - 5 344.16
8 [2.038879871368408, 1.563637614250183, 1.32510... [0.3195910453796386, 0.4616528749465942, 0.546... [1.6314889192581177, 1.4994858503341677, 1.474... [0.4371764659881592, 0.4931764602661133, 0.517... [58.19883918762207, 116.40550565719604, 174.34... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 RAW 200 10000 True [[1382 246 16 48 47 53 275 73 19... NORMAL - 15 1216.27
9 [2.007798671722412, 1.554110050201416, 1.32113... [0.3344993591308594, 0.4640406668186188, 0.547... [1.6179845333099363, 1.4908241033554075, 1.485... [0.4401045739650726, 0.4930196106433868, 0.519... [57.74453091621399, 115.19541382789612, 171.85... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 LOWER 200 10000 True [[1182 218 46 45 52 60 311 104 10... NORMAL - 15 1197.67
10 [2.0945076942443848, 1.5407600402832031, 1.260... [0.3058666586875915, 0.4837879538536072, 0.582... [1.6695539951324463, 1.3899223804473877, 1.346... [0.4403137266635895, 0.5398169755935669, 0.576... [55.19337034225464, 109.7722339630127, 164.612... 0 DENSE 64 Adam 0.001 - 50 - 0.58 0 0 DEFAULT 200 10000 True [[1521 213 24 13 44 56 146 17 9... NORMAL - 15 1143.94
11 [2.0426511764526367, 1.4889464378356934, 1.221... [0.331916332244873, 0.4952156841754913, 0.5902... [1.6068023443222046, 1.3842854499816897, 1.342... [0.4554771184921264, 0.5311895608901978, 0.570... [57.66135025024414, 114.47401738166808, 171.94... 0 DENSE 64 Adam 0.001 - 50 - 0.58 0 0 LOWER_I 200 10000 True [[1535 153 14 8 31 47 290 53 11... NORMAL - 15 1197.03
12 [2.0767838954925537, 1.5745370388031006, 1.335... [0.3215215802192688, 0.4560929536819458, 0.542... [1.649091720581055, 1.498886227607727, 1.46961... [0.4290196001529693, 0.4890457391738891, 0.523... [68.15565466880798, 133.44615292549133, 200.59... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 RAW 400 10000 True [[1258 179 48 47 35 38 323 37 10... NORMAL - 15 1409.46
13 [2.10368275642395, 1.6013576984405518, 1.34118... [0.3073516488075256, 0.4520435631275177, 0.544... [1.697721242904663, 1.4982041120529177, 1.4765... [0.4145882427692413, 0.4970980286598205, 0.525... [69.45309066772461, 136.9076645374298, 204.584... 0 DENSE 64 Adam 0.001 - 50 - 0.53 0 0 LOWER 400 10000 True [[1307 134 32 37 30 72 318 55 22... NORMAL - 15 1433.87
14 [2.148930549621582, 1.7112888097763062, 1.4490... [0.2799529433250427, 0.3945388495922088, 0.491... [1.811055302619934, 1.5882717370986938, 1.4911... [0.3535163402557373, 0.4577777683734894, 0.499... [64.82900762557983, 130.9426231384277, 195.940... 0 DENSE 64 Adam 0.001 - 50 - 0.55 0 0 DEFAULT 400 10000 True [[1439 88 23 22 27 74 353 36 18... NORMAL - 15 2350.03
15 [2.1052839756011963, 1.5557817220687866, 1.270... [0.3025882244110107, 0.4585562944412231, 0.561... [1.6754242181777954, 1.404281497001648, 1.3419... [0.414274513721466, 0.5210980176925659, 0.5575... [68.50114560127258, 136.1558701992035, 204.361... 0 DENSE 64 Adam 0.001 - 50 - 0.56 0 0 LOWER_I 400 10000 True [[1434 184 20 21 17 45 317 107 14... NORMAL - 15 1429.72
16 [1.2060670852661133, 0.5876715183258057, 0.301... [0.5213333368301392, 0.7960087060928345, 0.901... [0.7380061745643616, 0.6602489352226257, 0.777... [0.722196102142334, 0.7658039331436157, 0.7722... [42.88776779174805, 84.30938196182251, 125.639... 0 DENSE 64 Adam 0.001 - 300 - 0.76 0 0 RAW 200 10000 True [[1499 249 353 35 77]\n [ 109 1561 244 ... NORMAL - 5 627.37
17 [1.1218571662902832, 0.5309804677963257, 0.279... [0.5644392371177673, 0.8198866844177246, 0.911... [0.6812100410461426, 0.6584259867668152, 0.761... [0.7573333382606506, 0.7799215912818909, 0.788... [42.935046672821045, 84.6618320941925, 126.400... 0 DENSE 64 Adam 0.001 - 300 - 0.76 0 0 LOWER 200 10000 True [[1526 250 228 57 152]\n [ 111 1533 188 ... NORMAL - 5 630.26
18 [1.2983185052871704, 0.67186439037323, 0.35850... [0.4403764605522156, 0.7503268122673035, 0.875... [0.8450406789779663, 0.6062559485435486, 0.653... [0.6599215865135193, 0.7830588221549988, 0.803... [41.16602444648743, 81.34677410125732, 120.331... 0 DENSE 64 Adam 0.001 - 300 - 0.79 0 0 DEFAULT 200 10000 True [[1461 291 227 76 158]\n [ 174 1726 94 ... NORMAL - 5 604.90
19 [1.1131271123886108, 0.4970025420188904, 0.263... [0.5663529634475708, 0.8301350474357605, 0.912... [0.634172260761261, 0.5626846551895142, 0.6592... [0.7694117426872253, 0.8014117479324341, 0.809... [42.32026171684265, 83.03080439567566, 123.852... 0 DENSE 64 Adam 0.001 - 300 - 0.80 0 0 LOWER_I 200 10000 True [[1508 290 287 32 96]\n [ 77 1770 190 ... NORMAL - 5 618.58
20 [1.4313589334487915, 0.976482629776001, 0.5856... [0.3718431293964386, 0.5906579494476318, 0.785... [1.1860140562057495, 0.7914318442344666, 0.690... [0.492078423500061, 0.7022745013237, 0.7548235... [75.83154845237732, 148.4723603725433, 221.875... 0 DENSE 64 Adam 0.001 - 300 - 0.76 0 0 RAW 400 10000 True [[1521 245 329 36 82]\n [ 127 1526 201 ... NORMAL - 5 1535.43
21 [1.2526357173919678, 0.7180959582328796, 0.415... [0.4922509789466858, 0.7275816798210144, 0.859... [0.8532054424285889, 0.7139995694160461, 0.696... [0.6709019541740417, 0.7447842955589294, 0.779... [72.2411801815033, 143.07429265975952, 214.991... 0 DENSE 64 Adam 0.001 - 300 - 0.77 0 0 LOWER 400 10000 True [[1541 148 435 16 73]\n [ 74 1643 317 ... NORMAL - 5 1503.40
22 [1.612979292869568, 1.6098767518997192, 1.6102... [0.2574901878833771, 0.2004705816507339, 0.200... [1.6093101501464844, 1.6092772483825684, 1.609... [0.1923137307167053, 0.199215680360794, 0.1920... [71.83844518661499, 143.0635724067688, 215.440... 0 DENSE 64 Adam 0.001 - 300 - 0.21 0 0 DEFAULT 400 10000 True [[ 3 0 2210 0 0]\n [ 3 0 2200 ... NORMAL - 5 1079.05
23 [1.6123055219650269, 1.4136512279510498, 0.892... [0.1997019648551941, 0.339607834815979, 0.6405... [1.6083794832229614, 1.10409414768219, 0.74592... [0.1921568661928177, 0.5328627228736877, 0.719... [72.79997491836548, 144.0263090133667, 215.028... 0 DENSE 64 Adam 0.001 - 300 - 0.79 0 0 LOWER_I 400 10000 True [[1617 342 145 20 89]\n [ 121 1746 103 ... NORMAL - 5 2015.48
24 [2.010152578353882, 1.5290956497192385, 1.2335... [0.3290140032768249, 0.4641103744506836, 0.562... [1.6453301906585691, 1.5359132289886477, 1.550... [0.4279738664627075, 0.4795294106006622, 0.506... [122.05491280555724, 241.92626881599423, 364.0... 0 DENSE 64 Adam 0.001 - 300 - 0.49 0 0 RAW 200 10000 True [[1071 369 36 61 34 64 271 114 7... NORMAL - 15 1822.20
25 [2.0807976722717285, 1.5866154432296753, 1.302... [0.3101803958415985, 0.4424575269222259, 0.535... [1.694566249847412, 1.5601444244384766, 1.5679... [0.4060130715370178, 0.4670849740505218, 0.481... [122.93492102622986, 244.07795214653012, 366.3... 0 DENSE 64 Adam 0.001 - 300 - 0.47 0 0 LOWER 200 10000 True [[1157 287 34 93 38 72 240 96 18... NORMAL - 15 1827.86
26 [2.22964859008789, 1.658429503440857, 1.336077... [0.257966011762619, 0.4204676747322082, 0.5402... [1.8096245527267456, 1.5306107997894287, 1.440... [0.3592156767845154, 0.4859085083007812, 0.534... [124.48556709289552, 248.0258502960205, 368.81... 0 DENSE 64 Adam 0.001 - 300 - 0.54 0 0 DEFAULT 200 10000 True [[1180 105 23 97 50 109 413 26 7... NORMAL - 15 2570.65
27 [1.9838707447052, 1.4141194820404053, 1.121265... [0.3484601378440857, 0.5100711584091187, 0.607... [1.531771898269653, 1.3959068059921265, 1.3805... [0.4722614288330078, 0.5345359444618225, 0.558... [122.99041819572447, 245.43910098075867, 367.5... 0 DENSE 64 Adam 0.001 - 300 - 0.56 0 0 LOWER_I 200 10000 True [[1257 309 28 8 77 50 398 8 3... NORMAL - 15 2569.64
28 [2.298656940460205, 1.906553149223328, 1.59620... [0.2510797381401062, 0.3431517779827118, 0.438... [2.042023181915283, 1.7355191707611084, 1.6539... [0.3018562197685241, 0.4036078453063965, 0.439... [219.676766872406, 433.9881939888001, 650.6269... 0 DENSE 64 Adam 0.001 - 300 - 0.44 0 0 RAW 400 10000 True [[1050 246 42 204 46 94 260 77 4... NORMAL - 15 4540.63
29 [2.215775489807129, 1.8288313150405884, 1.5736... [0.2619398832321167, 0.3532491028308868, 0.434... [1.9354660511016848, 1.716828465461731, 1.6661... [0.327320247888565, 0.3944052159786224, 0.4274... [212.37938380241397, 417.8751857280731, 627.87... 0 DENSE 64 Adam 0.001 - 300 - 0.47 0 0 LOWER 400 10000 True [[1142 226 8 64 130 116 306 88 8... NORMAL - 15 5882.86
30 [2.7099225521087646, 2.708319902420044, 2.7084... [0.1086745113134384, 0.0673812627792358, 0.066... [2.7085378170013428, 2.708480834960937, 2.7085... [0.0658823549747467, 0.0657777786254882, 0.064... [213.5495536327362, 428.6686849594116, 646.272... 0 DENSE 64 Adam 0.001 - 300 - 0.43 0 0 DEFAULT 400 10000 True [[ 103 125 11 296 155 802 447 58 7... NORMAL - 15 11962.07
In [269]:
px.bar(
    dense_df,
    x='Key',
    y="Accuracy",
    color='EmbeddingSize',
    barmode="group",
    facet_col="SeqLen",
    facet_row="NumberOfAuthors",
    text=dense_df.Accuracy,
    title="Výsledky klasifikace autorů v závislosti na proměnných"
    
)

Jak lze pozorovat na výsledcích bylo experimentováno s velikostí vektoru, kterým bude reprezentováno slovo. Přičemž rozdíl mezi velikosti 50 a 300 nebyl markantní. Výsledky byly na všech typech předzpracovaných datech podobné. Můžeme s určitou jistotou říct, že pro tuto neuronovou síť stačí takhle velký číselný vektor, aby zachytil většinu relevantních vlastností slova. Velikost nad 50 nepřispěje k zlepšení přesnosti, akorát zvýší složitost neuronové sítě, neboli rozšíří počet parametrů, které bude potřeba naučit interně v neuronové síti, aby kvalitně bylo schopna předpovědět daného autora. U 15 autorů, pak lze pozorovat, že menší embedding hrál určitou roli v přesnosti.

Velikost vstupní sekvence jednou byla omezena na průměrnou velikost délky záznamu a v druhém případě na maximálně, takže žádná nebyla oříznuta. Spíše docházelo k operaci paddingu. U 15 autorů lze pozorovat jisté zlepšení u kratší sekvence.

Pro tento jednoduchý model lze pozorovat, že nejlepších výsledků dosahovalo předzpracování z knihovny gensim a slabší předzpracování, které transformovalo do malých písmen, odstranilo interpunkci, minimalizovalo počet bílých znaků. Důvodem může být, že tento model lépe pracoval právě s klíčovými slovy, která jsou maximálně relevantní. Šum bohužel zmenšoval přesnost modelu o jednotky procent.

Zároveň lze pozorovat, že 15 autorů pro síť bylo komplikované rozeznat. Přesnost kolem 60 % není nijak extra zajímavá. Narozdíl s 5 autory si síť dokázala poradit s přesností skoro 80 %.

Rekurentní neuronová síť¶

In [270]:
rnn_df = df[df.ModelName == "Bidirectional GRU"] 
In [271]:
rnn_df
Out[271]:
loss accuracy val_loss val_accuracy time NaN ModelName BatchSize Optimizer LR Epochs EmbeddingSize Time Accuracy Hits Miss Key SeqLen VocabSize TrainableEmbedding ConfMatrix Type TransformerName NumberOfAuthors CalculationTime
31 [6687085056.0, 1187478372352.0, 1.320836067199... [0.3819360733032226, 0.4748932421207428, 0.597... [nan, 1.0665013790130615, 0.9180729389190674, ... [0.294117659330368, 0.5433725714683533, 0.6208... [1092.711680173874, 2239.074378967285, 3320.77... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.69 0 0 RAW 300 10000 True [[1428 125 512 34 114]\n [ 50 1417 393 ... NORMAL - 5 62433.80
32 [2406436.75, 25465.07421875, nan, nan, nan] [0.3964235186576843, 0.5264662504196167, 0.383... [1.1757404804229736, 1.1074206829071045, nan, ... [0.485647052526474, 0.5080784559249878, 0.1921... [1368.7052409648895, 2801.06485581398, 4307.10... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.51 0 0 LOWER 300 10000 True [[1731 434 4 3 41]\n [ 601 1472 94 ... NORMAL - 5 21699.77
33 [1.3197650909423828, 183.88555908203125, 0.674... [0.3472784459590912, 0.5911459922790527, 0.748... [1.1397396326065063, 0.7558140158653259, 0.713... [0.4633725583553314, 0.7105882167816162, 0.735... [1794.731845855713, 3350.3680169582367, 4541.5... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.74 0 0 DEFAULT 300 10000 True [[1087 167 730 68 161]\n [ 17 1742 242 ... NORMAL - 5 30416.04
34 [278055.28125, 1.4219011068344116, 31504.09765... [0.3278745114803314, 0.4618736505508423, 0.516... [1.2132927179336548, 1.139415979385376, 1.1058... [0.4410980343818664, 0.4923921525478363, 0.526... [1202.219447374344, 2394.1133301258087, 3581.0... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.75 0 0 LOWER_I 300 10000 True [[1397 428 212 64 112]\n [ 143 1533 396 ... NORMAL - 5 65487.81
35 [nan, nan, nan] [0.0903417393565177, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [3581.0230338573456, 7243.392804384232, 13244.... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.07 0 0 RAW 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15 24069.29
36 [nan, nan, nan] [0.0666928067803382, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [5272.877393007278, 10100.672457456589, 14649.... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.07 0 0 LOWER 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15 30023.22
37 [143974.984375, 1.981084942817688, nan, nan, nan] [0.1826091557741165, 0.2948496639728546, 0.301... [2.123115539550781, 1.8301618099212649, nan, n... [0.2459607869386673, 0.3584313690662384, 0.068... [4745.052874326706, 9517.231180667875, 14253.2... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.36 0 0 DEFAULT 300 10000 True [[ 552 57 29 485 1 132 234 68 4... NORMAL - 15 70875.74
38 [nan, nan, nan] [0.0665464028716087, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [4647.039727687836, 9398.7957239151, 14184.882... 0 Bidirectional GRU 64 Adam 0.001 - 50 - 0.07 0 0 LOWER_I 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15 28230.72
39 [nan, nan, nan] [0.1690849661827087, 0.2015163451433181, 0.201... [nan, nan, nan] [0.1921568661928177, 0.1921568661928177, 0.192... [1732.056715965271, 3492.265405893326, 5230.93... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.20 0 0 RAW 300 10000 True [[2213 0 0 0 0]\n [2203 0 0 ... NORMAL - 5 10455.26
40 [nan, nan, nan] [0.3445647060871124, 0.2015163451433181, 0.201... [nan, nan, nan] [0.1921568661928177, 0.1921568661928177, 0.192... [1745.630146026611, 3480.498616695404, 4916.81... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.20 0 0 LOWER 300 10000 True [[2213 0 0 0 0]\n [2203 0 0 ... NORMAL - 5 10142.94
41 [9.50131130218506, 0.7876715660095215, 1.23750... [0.3992627561092376, 0.6950762271881104, 0.808... [0.9494749307632446, 0.6531278491020203, 0.620... [0.6191372275352478, 0.7658039331436157, 0.788... [1764.6829175949097, 3511.476901292801, 5241.2... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.80 0 0 DEFAULT 300 10000 True [[1597 259 167 86 104]\n [ 151 1747 64 ... NORMAL - 5 48701.47
42 [117312464.0, nan, nan, nan] [0.4742588102817535, 0.4018823504447937, 0.201... [1.0666706562042236, nan, nan, nan] [0.5543529391288757, 0.1921568661928177, 0.192... [1787.1415581703186, 3542.0427582263947, 5307.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.57 0 0 LOWER_I 300 10000 True [[1077 131 929 3 73]\n [ 132 1246 712 ... NORMAL - 5 17682.81
43 [nan, nan, nan] [0.1155686303973198, 0.0660798847675323, 0.066... [nan, nan, nan] [0.0684967339038848, 0.0684967339038848, 0.068... [5241.320497512817, 10400.483037471771, 15546.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.07 0 0 RAW 300 10000 True [[2316 0 0 0 0 0 0 0 0... NORMAL - 15 31188.12
44 [4.139690399169922, 31051933696.0, 5946009.0, ... [0.2126849740743637, 0.3071604967117309, 0.362... [45.60334777832031, 1.8750523328781128, 1.7882... [0.2989281117916107, 0.3549803793430328, 0.382... [4035.43758225441, 7763.776393175125, 11458.88... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.56 0 0 LOWER 300 10000 True [[1488 176 59 4 126 19 158 42 5... NORMAL - 15 207936.63
45 [3889620.75, 5.671187400817871, 174.6585388183... [0.2653124034404754, 0.3765461146831512, 0.458... [74537.96875, 1.6551882028579712, 3.5518321990... [0.3310849666595459, 0.4476862847805023, 0.468... [3747.2181475162506, 7466.645300388336, 11197.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.57 0 0 DEFAULT 300 10000 True [[1448 185 41 29 71 61 236 37 5... NORMAL - 15 196815.26
46 [47510450176.0, 10461987.0, 3.044379234313965,... [0.2487633973360061, 0.3066957294940948, 0.376... [2.28581562481608e+19, 1.265042553581863e+18, ... [0.3026405274868011, 0.3697254955768585, 0.433... [3202.14945268631, 6400.369203567505, 9596.970... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.43 0 0 LOWER_I 300 10000 True [[ 918 165 76 21 2 162 512 56 9... NORMAL - 15 67171.45
In [272]:
px.bar(
    rnn_df,
    x='Key',
    y="Accuracy",
    color='EmbeddingSize',
    barmode="group",
    facet_col="SeqLen",
    facet_row="NumberOfAuthors",
    text=rnn_df.Accuracy,
    title="Výsledky klasifikace autorů u RNN v závislosti na proměnných"
)

Jak lze pozorovat v mnoha případech síť nedokázala konvergovat k dobrým parametrům, aby mohla rozeznat autora. Můžeme se zde zaměřit pouze na výsledky s velikostí 50 a předzpracování pomocí gensim knihovny. Výsledná přesnost i přes takto složitější model nebyla poznatelně lepší. Bylo by potřeba v rámci projektu provést větší exploraci nad architekturou, aby bylo zjištěno, proč síť nemá tendenci správně konvergovat.

Transformery¶

In [273]:
transformer_df = df[df.ModelName == "Transformer"] 
In [274]:
transformer_df
Out[274]:
loss accuracy val_loss val_accuracy time NaN ModelName BatchSize Optimizer LR Epochs EmbeddingSize Time Accuracy Hits Miss Key SeqLen VocabSize TrainableEmbedding ConfMatrix Type TransformerName NumberOfAuthors CalculationTime
47 [1.721323847770691, 1.6283942461013794, 1.6182... [0.167633980512619, 0.200139433145523, 0.20036... [1.6598328351974487, 1.6358952522277832, 1.622... [0.1987451016902923, 0.1921568661928177, 0.192... [12731.715883731842, 25516.930138111115, 38309... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[2213 0 0 0 0]\n [2203 0 0 ... TL distilbert-base-uncased 5 76558.45
48 [0.6788761615753174, 0.2942142188549042, 0.161... [0.6950744986534119, 0.8956339955329895, 0.945... [0.4689745604991913, 0.3302323520183563, 0.374... [0.8304314017295837, 0.8854901790618896, 0.882... [12792.493296146393, 25582.148589611053, 38391... 0 Transformer 128 Adam 5e-05 - - - 0.88 0 0 LOWER_I 300 - True [[1932 65 99 57 60]\n [ 96 1780 103 ... TL distilbert-base-uncased 5 76765.80
49 [1.720706582069397, 1.62993323802948, 1.619087... [0.2695686221122741, 0.1993899792432785, 0.199... [1.612503170967102, 1.6107159852981567, 1.6184... [0.2062745094299316, 0.2043921500444412, 0.198... [25539.273556947708, 51105.00569915772, 76684.... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[ 0 0 0 0 2213]\n [ 0 0 0 ... TL bert-base-uncased 5 153328.81
50 [0.6129876375198364, 0.2475549280643463, 0.137... [0.7192941308021545, 0.9139520525932312, 0.954... [0.3863182961940765, 0.3083285987377167, 0.356... [0.8680784106254578, 0.8953725695610046, 0.892... [25577.77806091309, 51058.13135123253, 76546.3... 0 Transformer 128 Adam 5e-05 - - - 0.89 0 0 LOWER_I 300 - True [[1845 72 102 118 76]\n [ 34 1847 67 ... TL bert-base-uncased 5 153182.25
51 [1.7034351825714111, 1.6270469427108765, 1.617... [0.2704313695430755, 0.198047935962677, 0.1987... [1.6337881088256836, 1.6143453121185305, 1.611... [0.2062745094299316, 0.2043921500444412, 0.198... [8417.139698982239, 16806.370790719986, 25206.... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[ 0 0 0 0 2213]\n [ 0 0 0 ... TL google/electra-small-discriminator 5 50429.72
52 [1.167790174484253, 0.668626070022583, 0.48968... [0.5210353136062622, 0.760610044002533, 0.8271... [0.8901610970497131, 0.8246893286705017, 0.576... [0.6842352747917175, 0.7243921756744385, 0.799... [8379.385590314865, 16772.64554834366, 25164.1... 0 Transformer 128 Adam 5e-05 - - - 0.81 0 0 LOWER_I 300 - True [[1545 117 274 131 146]\n [ 58 1393 243 ... TL google/electra-small-discriminator 5 50316.22
In [275]:
px.bar(
    transformer_df,
    x='Key',
    y="Accuracy",
    color='TransformerName',
    barmode="group",
    facet_col="LR",
    facet_row="NumberOfAuthors",
    text=transformer_df.Accuracy,
    title="Výsledky klasifikace autorů u Transformeru v závilosti na proměnných"
)

V experimentech bylo provedeno pouze pár běhů, jelikož transformery při tunění mají tendenci běžet delší dobu z důvodu široké interní reprezentace.

V první řadě byly provedeny běhy s učící konstantou 0.001, která je defaultně nastavena u Adam optimizeru. Vyšlo najevo, že takhle vysoká hodnota pro Transformery není vhod. Síť při tak velké učící konstantně nemá schopnost naučit se správnou interní reprezentaci. Další hodnota byla snížena na mnohem menší číslo, přesně 5e-05. Po této změně již všechny typy transformeru konvergovaly ke kvalitnímu výsledku.

Experimenty byly provedeny na 3 typech transformeru, a to:

  • DistilBert
  • Bert
  • Electra

Účelem bylo porovnání běhu těchto modelů a zároveň zhodnotit jejich přesnost predikce.

Jak lze tedy z výsledku pozorovat nejlépe dopadl nejkomplexnější a nejstarší model Bert, který překonal "jednoduché neuronové sítě" skoro o 10 procent u 5 autorů. Podobný posun lze možná předpokládat i u více autorů. Hnedka za Bert typem se dostala zjednodušená verze DistilBertu s přesností 88 %. Electra pak dosahovala slabších výsledků 81 %.

In [276]:
px.bar(
    transformer_df,
    x='Key',
    y="CalculationTime",
    color='TransformerName',
    barmode="group",
    facet_col="LR",
    facet_row="NumberOfAuthors",
    text=transformer_df.CalculationTime,
    title="Výsledky klasifikace autorů u Transformeru v závilosti na proměnných"
)

Časová složitost je odstupňována, dle složitost modelu:

  • Bert - 3/3
  • Distibert - 1/2
  • Electra - 1/3

Po zhodnocení dostupných dat bychom se pravděpodobně vybrali model DistilBert, běžel vcelku rychle, dosahoval stejné přesnosti jako úplný model Bert. Důležité je myslet na nízkou učící konstantu.

In [277]:
transformer_df
Out[277]:
loss accuracy val_loss val_accuracy time NaN ModelName BatchSize Optimizer LR Epochs EmbeddingSize Time Accuracy Hits Miss Key SeqLen VocabSize TrainableEmbedding ConfMatrix Type TransformerName NumberOfAuthors CalculationTime
47 [1.721323847770691, 1.6283942461013794, 1.6182... [0.167633980512619, 0.200139433145523, 0.20036... [1.6598328351974487, 1.6358952522277832, 1.622... [0.1987451016902923, 0.1921568661928177, 0.192... [12731.715883731842, 25516.930138111115, 38309... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[2213 0 0 0 0]\n [2203 0 0 ... TL distilbert-base-uncased 5 76558.45
48 [0.6788761615753174, 0.2942142188549042, 0.161... [0.6950744986534119, 0.8956339955329895, 0.945... [0.4689745604991913, 0.3302323520183563, 0.374... [0.8304314017295837, 0.8854901790618896, 0.882... [12792.493296146393, 25582.148589611053, 38391... 0 Transformer 128 Adam 5e-05 - - - 0.88 0 0 LOWER_I 300 - True [[1932 65 99 57 60]\n [ 96 1780 103 ... TL distilbert-base-uncased 5 76765.80
49 [1.720706582069397, 1.62993323802948, 1.619087... [0.2695686221122741, 0.1993899792432785, 0.199... [1.612503170967102, 1.6107159852981567, 1.6184... [0.2062745094299316, 0.2043921500444412, 0.198... [25539.273556947708, 51105.00569915772, 76684.... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[ 0 0 0 0 2213]\n [ 0 0 0 ... TL bert-base-uncased 5 153328.81
50 [0.6129876375198364, 0.2475549280643463, 0.137... [0.7192941308021545, 0.9139520525932312, 0.954... [0.3863182961940765, 0.3083285987377167, 0.356... [0.8680784106254578, 0.8953725695610046, 0.892... [25577.77806091309, 51058.13135123253, 76546.3... 0 Transformer 128 Adam 5e-05 - - - 0.89 0 0 LOWER_I 300 - True [[1845 72 102 118 76]\n [ 34 1847 67 ... TL bert-base-uncased 5 153182.25
51 [1.7034351825714111, 1.6270469427108765, 1.617... [0.2704313695430755, 0.198047935962677, 0.1987... [1.6337881088256836, 1.6143453121185305, 1.611... [0.2062745094299316, 0.2043921500444412, 0.198... [8417.139698982239, 16806.370790719986, 25206.... 0 Transformer 128 Adam 0.001 - - - 0.20 0 0 LOWER_I 300 - True [[ 0 0 0 0 2213]\n [ 0 0 0 ... TL google/electra-small-discriminator 5 50429.72
52 [1.167790174484253, 0.668626070022583, 0.48968... [0.5210353136062622, 0.760610044002533, 0.8271... [0.8901610970497131, 0.8246893286705017, 0.576... [0.6842352747917175, 0.7243921756744385, 0.799... [8379.385590314865, 16772.64554834366, 25164.1... 0 Transformer 128 Adam 5e-05 - - - 0.81 0 0 LOWER_I 300 - True [[1545 117 274 131 146]\n [ 58 1393 243 ... TL google/electra-small-discriminator 5 50316.22
In [278]:
def create_dfs(selector):
    dfs = []

    for i in range(len(transformer_df)):
        row = transformer_df.iloc[i, :]
        column = row[selector]
        name = row.TransformerName + " " + row.LR
        current_df = pd.DataFrame()
        current_df['Type'] = len(column) * [name]
        current_df['Value'] = column

        dfs.append(current_df)
        
    return dfs
In [279]:
df_val_loss = pd.concat(create_dfs('val_loss'))
In [280]:
df_val_acc = pd.concat(create_dfs('accuracy'))
In [281]:
px.line(df_val_loss, y='Value', color='Type', title="Chyba v transformerech")
In [282]:
px.line(df_val_acc, y='Value', color='Type', title="Přesnost v transformerech")

Na grafech lze vidět zmíněnou stagnaci modelu při vysoké učící konstantě a zároveň fakt, že Electra by možná mohla dosáhnout lepších výsledků, pakliže by mohlo dále běžet učení model. Výsledkem by mohlo být překonání Berta.

In [283]:
bert_confusion_df = transformer_df[transformer_df.TransformerName == "distilbert-base-uncased"]
bert_confusion_df = bert_confusion_df[bert_confusion_df.LR == "5e-05"]
bert_confusion = bert_confusion_df.ConfMatrix
In [284]:
def parse_confusion(conf):
    conf = conf[1:len(conf)-1]
    splited = conf.split('\n')
    matrix = []
    for row in splited:
        row = row.strip()
        row = row[1:len(row)-1]
        row = strip_multiple_whitespaces_gensim(row).strip()
        numbers = [int(n) for n in row.split(' ')]
        matrix.append(numbers)
    return np.array(matrix)
In [285]:
conf = parse_confusion(bert_confusion.values[0])
In [286]:
plt.figure(figsize=[20, 20]) 
sns.heatmap(conf, annot=True, fmt='g', cmap='Blues')
Out[286]:
<AxesSubplot:>

Z vizualizované matice záměn lze pozorovat, že často dochází k záměně u 1 proti 2 a 3.

Výběr nejlepších¶

In [287]:
df_5 = df[df.NumberOfAuthors == "5"].reset_index()
In [288]:
df_5 = df_5.sort_values(by='Accuracy', ascending=False)
In [289]:
df_5['ConstructedModelName'] = df_5.ModelName + df_5.TransformerName  
df_5 = df_5.iloc[0:5, :]
In [290]:
px.bar(
    df_5,
    x='ConstructedModelName',
    y="Accuracy",
    text=df_5.Accuracy,
    title="Nejlepších 5 modelů u klasifikace 5 autorů",
    color='ConstructedModelName'
)

Na grafu lze vidět, že jednoduchá hluboká neuronová síť dosáhla stejného výsledku jako Electra. Bert s 100m parametrů, pak zcela kraluje všem ostatním modelům.

In [291]:
px.bar(
    df_5,
    x='ConstructedModelName',
    y="Accuracy",
    text=df_5.Accuracy,
    title="Nejlepších 5 modelů u klasifikace 5 autorů",
    color='ConstructedModelName'
)
In [292]:
px.bar(
    df_5,
    x='ConstructedModelName',
    y="CalculationTime",
    text=df_5.CalculationTime,
    title="Časová náročnost 5 modelů u klasifikace 5 autorů",
    color='ConstructedModelName'
)

Je nutné pozorovat, že vyšší přesnost je za cenu mnohem vyšší časové složitosti. V tomto případě by šlo aplikovat paralelizaci na transformer modely, ale i přesto náročnost na výpočet lze pozorovat mnohem vyšší. Proto je nutné vždy zvážit, zda chceme benefitovat ze zvýšené přesnosti i na úkor tak vysoké složitosti.

In [293]:
df_15 = df[df.NumberOfAuthors == "15"].reset_index()
In [294]:
df_15 = df_15.sort_values(by='Accuracy', ascending=False)
In [295]:
df_15['ConstructedModelName'] = df_15.ModelName + df_15.TransformerName + [str(x) for x in df_15.index]
df_15 = df_15.iloc[0:5, :] 
In [296]:
df_15
Out[296]:
index loss accuracy val_loss val_accuracy time NaN ModelName BatchSize Optimizer LR Epochs EmbeddingSize Time Accuracy Hits Miss Key SeqLen VocabSize TrainableEmbedding ConfMatrix Type TransformerName NumberOfAuthors CalculationTime ConstructedModelName
2 10 [2.0945076942443848, 1.5407600402832031, 1.260... [0.3058666586875915, 0.4837879538536072, 0.582... [1.6695539951324463, 1.3899223804473877, 1.346... [0.4403137266635895, 0.5398169755935669, 0.576... [55.19337034225464, 109.7722339630127, 164.612... 0 DENSE 64 Adam 0.001 - 50 - 0.58 0 0 DEFAULT 200 10000 True [[1521 213 24 13 44 56 146 17 9... NORMAL - 15 1143.94 DENSE-2
3 11 [2.0426511764526367, 1.4889464378356934, 1.221... [0.331916332244873, 0.4952156841754913, 0.5902... [1.6068023443222046, 1.3842854499816897, 1.342... [0.4554771184921264, 0.5311895608901978, 0.570... [57.66135025024414, 114.47401738166808, 171.94... 0 DENSE 64 Adam 0.001 - 50 - 0.58 0 0 LOWER_I 200 10000 True [[1535 153 14 8 31 47 290 53 11... NORMAL - 15 1197.03 DENSE-3
21 45 [3889620.75, 5.671187400817871, 174.6585388183... [0.2653124034404754, 0.3765461146831512, 0.458... [74537.96875, 1.6551882028579712, 3.5518321990... [0.3310849666595459, 0.4476862847805023, 0.468... [3747.2181475162506, 7466.645300388336, 11197.... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.57 0 0 DEFAULT 300 10000 True [[1448 185 41 29 71 61 236 37 5... NORMAL - 15 196815.26 Bidirectional GRU-21
11 27 [1.9838707447052, 1.4141194820404053, 1.121265... [0.3484601378440857, 0.5100711584091187, 0.607... [1.531771898269653, 1.3959068059921265, 1.3805... [0.4722614288330078, 0.5345359444618225, 0.558... [122.99041819572447, 245.43910098075867, 367.5... 0 DENSE 64 Adam 0.001 - 300 - 0.56 0 0 LOWER_I 200 10000 True [[1257 309 28 8 77 50 398 8 3... NORMAL - 15 2569.64 DENSE-11
20 44 [4.139690399169922, 31051933696.0, 5946009.0, ... [0.2126849740743637, 0.3071604967117309, 0.362... [45.60334777832031, 1.8750523328781128, 1.7882... [0.2989281117916107, 0.3549803793430328, 0.382... [4035.43758225441, 7763.776393175125, 11458.88... 0 Bidirectional GRU 64 Adam 0.001 - 300 - 0.56 0 0 LOWER 300 10000 True [[1488 176 59 4 126 19 158 42 5... NORMAL - 15 207936.63 Bidirectional GRU-20
In [297]:
px.bar(
    df_15,
    x='ConstructedModelName',
    y="Accuracy",
    text=df_15.Accuracy,
    title="Nejlepších 5 modelů u klasifikace 15 autorů",
    color='ConstructedModelName'
)

Rozdíl proti 5 autorům je dost poznatelný, jedná se o 30 %. Nutné je podotknout, že v těchto případech model má větší náchylnost na zmatení stylem ostatních autorů.

In [298]:
px.bar(
    df_15,
    x='ConstructedModelName',
    y="CalculationTime",
    text=df_15.CalculationTime,
    title="Časová náročnost 5 modelů u klasifikace 15 autorů",
    color='ConstructedModelName'
)

Časová náročnost je násobně vyšší u oboustranné RNN sítě. Důvodem je, že velikost vstupní sekvence je 200 a 300. Zpracování probíhá slovo po slovu, a proto trvání je takové.

Závěr¶

  • Hluboká neuronová síť dosahuje celkem vysokých přesností, i když se jedná o tak jednoduchý model. Embedding vrstva odvádí skvělou práci, kdy je schopná se velice efektivně i přes jednoduchý model naučit číselnou reprezentaci pro každé slovo.

  • RNN nedosahovala dobrých výsledků, bylo by potřeba provést více experimentální činnosti.

  • Transformer model má vysokou časovou náročnost a je nutné specifikovat nízkou učící konstantu.

  • Transformer model dosáhl nejlepších výsledků u 5 autorů a to 89 procent.

  • Transformer model stačí malé množství epoch k tomu, aby byl schopný vyřešit náš definovaný problém.

  • Electra model umožňuje delší učení.